6 #include "tcpiohandler.hh"
8 const bool TCPIOHandler::s_disableConnectForUnitTests
= false;
12 #endif /* HAVE_LIBSODIUM */
14 #if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
17 #include <openssl/conf.h>
18 #include <openssl/err.h>
19 #include <openssl/rand.h>
20 #include <openssl/ssl.h>
21 #include <openssl/x509v3.h>
26 class OpenSSLFrontendContext
29 OpenSSLFrontendContext(const ComboAddress
& addr
, const TLSConfig
& tlsConfig
): d_ticketKeys(tlsConfig
.d_numberOfTicketsKeys
)
31 registerOpenSSLUser();
33 auto [ctx
, warnings
] = libssl_init_server_context(tlsConfig
, d_ocspResponses
);
34 for (const auto& warning
: warnings
) {
35 warnlog("%s", warning
);
37 d_tlsCtx
= std::move(ctx
);
40 ERR_print_errors_fp(stderr
);
41 throw std::runtime_error("Error creating TLS context on " + addr
.toStringWithPort());
49 unregisterOpenSSLUser();
52 OpenSSLTLSTicketKeysRing d_ticketKeys
;
53 std::map
<int, std::string
> d_ocspResponses
;
54 std::unique_ptr
<SSL_CTX
, void(*)(SSL_CTX
*)> d_tlsCtx
{nullptr, SSL_CTX_free
};
55 pdns::UniqueFilePtr d_keyLogFile
{nullptr};
58 class OpenSSLSession
: public TLSSession
61 OpenSSLSession(std::unique_ptr
<SSL_SESSION
, void(*)(SSL_SESSION
*)>&& sess
): d_sess(std::move(sess
))
65 std::unique_ptr
<SSL_SESSION
, void(*)(SSL_SESSION
*)> getNative()
67 return std::move(d_sess
);
71 std::unique_ptr
<SSL_SESSION
, void(*)(SSL_SESSION
*)> d_sess
;
74 class OpenSSLTLSConnection
: public TLSConnection
77 /* server side connection */
78 OpenSSLTLSConnection(int socket
, const struct timeval
& timeout
, std::shared_ptr
<OpenSSLFrontendContext
> feContext
): d_feContext(std::move(feContext
)), d_conn(std::unique_ptr
<SSL
, void(*)(SSL
*)>(SSL_new(d_feContext
->d_tlsCtx
.get()), SSL_free
)), d_timeout(timeout
)
83 vinfolog("Error creating TLS object");
85 ERR_print_errors_fp(stderr
);
87 throw std::runtime_error("Error creating TLS object");
90 if (!SSL_set_fd(d_conn
.get(), d_socket
)) {
91 throw std::runtime_error("Error assigning socket");
94 SSL_set_ex_data(d_conn
.get(), getConnectionIndex(), this);
97 /* client-side connection */
98 OpenSSLTLSConnection(const std::string
& hostname
, bool hostIsAddr
, int socket
, const struct timeval
& timeout
, std::shared_ptr
<SSL_CTX
>& tlsCtx
): d_tlsCtx(tlsCtx
), d_conn(std::unique_ptr
<SSL
, void(*)(SSL
*)>(SSL_new(tlsCtx
.get()), SSL_free
)), d_hostname(hostname
), d_timeout(timeout
)
103 vinfolog("Error creating TLS object");
105 ERR_print_errors_fp(stderr
);
107 throw std::runtime_error("Error creating TLS object");
110 if (!SSL_set_fd(d_conn
.get(), d_socket
)) {
111 throw std::runtime_error("Error assigning socket");
114 /* set outgoing Server Name Indication */
115 if (!d_hostname
.empty() && SSL_set_tlsext_host_name(d_conn
.get(), d_hostname
.c_str()) != 1) {
116 throw std::runtime_error("Error setting TLS SNI to " + d_hostname
);
120 #if (OPENSSL_VERSION_NUMBER >= 0x10002000L)
121 X509_VERIFY_PARAM
*param
= SSL_get0_param(d_conn
.get());
122 /* Enable automatic IP checks */
123 X509_VERIFY_PARAM_set_hostflags(param
, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS
);
124 if (X509_VERIFY_PARAM_set1_ip_asc(param
, d_hostname
.c_str()) != 1) {
125 throw std::runtime_error("Error setting TLS IP for certificate validation");
128 /* no validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
132 #if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && defined(HAVE_SSL_SET_HOSTFLAGS) // grrr libressl
133 SSL_set_hostflags(d_conn
.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS
);
134 if (SSL_set1_host(d_conn
.get(), d_hostname
.c_str()) != 1) {
135 throw std::runtime_error("Error setting TLS hostname for certificate validation");
137 #elif (OPENSSL_VERSION_NUMBER >= 0x10002000L)
138 X509_VERIFY_PARAM
*param
= SSL_get0_param(d_conn
.get());
139 /* Enable automatic hostname checks */
140 X509_VERIFY_PARAM_set_hostflags(param
, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS
);
141 if (X509_VERIFY_PARAM_set1_host(param
, d_hostname
.c_str(), d_hostname
.size()) != 1) {
142 throw std::runtime_error("Error setting TLS hostname for certificate validation");
145 /* no hostname validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
149 SSL_set_ex_data(d_conn
.get(), getConnectionIndex(), this);
152 std::vector
<int> getAsyncFDs() override
154 std::vector
<int> results
;
155 #ifdef SSL_MODE_ASYNC
156 if (SSL_waiting_for_async(d_conn
.get()) != 1) {
160 OSSL_ASYNC_FD fds
[32];
161 size_t numfds
= sizeof(fds
)/sizeof(*fds
);
162 SSL_get_all_async_fds(d_conn
.get(), nullptr, &numfds
);
167 SSL_get_all_async_fds(d_conn
.get(), fds
, &numfds
);
168 results
.reserve(numfds
);
169 for (size_t idx
= 0; idx
< numfds
; idx
++) {
170 results
.push_back(fds
[idx
]);
176 IOState
convertIORequestToIOState(int res
) const
178 int error
= SSL_get_error(d_conn
.get(), res
);
179 if (error
== SSL_ERROR_WANT_READ
) {
180 return IOState::NeedRead
;
182 else if (error
== SSL_ERROR_WANT_WRITE
) {
183 return IOState::NeedWrite
;
185 else if (error
== SSL_ERROR_SYSCALL
) {
187 throw std::runtime_error("TLS connection closed by remote end");
190 throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(errno
)));
193 else if (error
== SSL_ERROR_ZERO_RETURN
) {
194 throw std::runtime_error("TLS connection closed by remote end");
196 #ifdef SSL_MODE_ASYNC
197 else if (error
== SSL_ERROR_WANT_ASYNC
) {
198 return IOState::Async
;
203 throw std::runtime_error("Error while processing TLS connection: (" + std::to_string(error
) + ") " + libssl_get_error_string());
205 throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error
));
210 void handleIORequest(int res
, const struct timeval
& timeout
)
212 auto state
= convertIORequestToIOState(res
);
213 if (state
== IOState::NeedRead
) {
214 res
= waitForData(d_socket
, timeout
.tv_sec
, timeout
.tv_usec
);
216 throw std::runtime_error("Timeout while reading from TLS connection");
219 throw std::runtime_error("Error waiting to read from TLS connection");
222 else if (state
== IOState::NeedWrite
) {
223 res
= waitForRWData(d_socket
, false, timeout
.tv_sec
, timeout
.tv_usec
);
225 throw std::runtime_error("Timeout while writing to TLS connection");
228 throw std::runtime_error("Error waiting to write to TLS connection");
233 IOState
tryConnect(bool fastOpen
, const ComboAddress
& remote
) override
239 int res
= SSL_connect(d_conn
.get());
241 return IOState::Done
;
244 return convertIORequestToIOState(res
);
247 throw std::runtime_error("Error establishing a TLS connection");
250 void connect(bool fastOpen
, const ComboAddress
& remote
, const struct timeval
&timeout
) override
256 struct timeval start
{0,0};
257 struct timeval remainingTime
= timeout
;
258 if (timeout
.tv_sec
!= 0 || timeout
.tv_usec
!= 0) {
259 gettimeofday(&start
, nullptr);
264 res
= SSL_connect(d_conn
.get());
266 handleIORequest(res
, remainingTime
);
269 if (timeout
.tv_sec
!= 0 || timeout
.tv_usec
!= 0) {
271 gettimeofday(&now
, nullptr);
272 struct timeval elapsed
= now
- start
;
273 if (now
< start
|| remainingTime
< elapsed
) {
274 throw runtime_error("Timeout while establishing TLS connection");
277 remainingTime
= remainingTime
- elapsed
;
283 IOState
tryHandshake() override
286 /* In client mode, the handshake is initiated by the call to SSL_connect()
287 done from connect()/tryConnect().
288 In blocking mode it does not return before the handshake has been finished,
289 and in non-blocking mode calling SSL_connect() once is enough for SSL_write()
290 and SSL_read() to transparently continue to negotiate the connection after that
291 (equivalent to doing SSL_set_connect_state() plus trying to write).
293 return IOState::Done
;
296 /* As explained above in the client-mode block, we only need to call SSL_accept() once
297 for SSL_write() and SSL_read() to transparently continue to negotiate the connection after that.
298 It is equivalent to calling SSL_set_accept_state() plus trying to read.
300 int res
= SSL_accept(d_conn
.get());
302 return IOState::Done
;
305 return convertIORequestToIOState(res
);
308 throw std::runtime_error("Error accepting TLS connection");
311 void doHandshake() override
314 /* we are a client, nothing to do, see the non-blocking version */
320 res
= SSL_accept(d_conn
.get());
322 handleIORequest(res
, d_timeout
);
328 throw std::runtime_error("Error accepting TLS connection");
332 IOState
tryWrite(const PacketBuffer
& buffer
, size_t& pos
, size_t toWrite
) override
334 if (!d_feContext
&& !d_connected
) {
336 /* work-around to get kTLS to be started, as we cannot do that until after the socket has been connected */
337 SSL_set_fd(d_conn
.get(), SSL_get_fd(d_conn
.get()));
342 int res
= SSL_write(d_conn
.get(), reinterpret_cast<const char *>(&buffer
.at(pos
)), static_cast<int>(toWrite
- pos
));
344 return convertIORequestToIOState(res
);
347 pos
+= static_cast<size_t>(res
);
350 while (pos
< toWrite
);
356 return IOState::Done
;
359 IOState
tryRead(PacketBuffer
& buffer
, size_t& pos
, size_t toRead
, bool allowIncomplete
) override
362 int res
= SSL_read(d_conn
.get(), reinterpret_cast<char *>(&buffer
.at(pos
)), static_cast<int>(toRead
- pos
));
364 return convertIORequestToIOState(res
);
367 pos
+= static_cast<size_t>(res
);
368 if (allowIncomplete
) {
373 while (pos
< toRead
);
374 return IOState::Done
;
377 size_t read(void* buffer
, size_t bufferSize
, const struct timeval
& readTimeout
, const struct timeval
& totalTimeout
, bool allowIncomplete
) override
380 struct timeval start
= {0, 0};
381 struct timeval remainingTime
= totalTimeout
;
382 if (totalTimeout
.tv_sec
!= 0 || totalTimeout
.tv_usec
!= 0) {
383 gettimeofday(&start
, nullptr);
387 int res
= SSL_read(d_conn
.get(), (reinterpret_cast<char *>(buffer
) + got
), static_cast<int>(bufferSize
- got
));
389 handleIORequest(res
, readTimeout
);
392 got
+= static_cast<size_t>(res
);
393 if (allowIncomplete
) {
398 if (totalTimeout
.tv_sec
!= 0 || totalTimeout
.tv_usec
!= 0) {
400 gettimeofday(&now
, nullptr);
401 struct timeval elapsed
= now
- start
;
402 if (now
< start
|| remainingTime
< elapsed
) {
403 throw runtime_error("Timeout while reading data");
406 remainingTime
= remainingTime
- elapsed
;
409 while (got
< bufferSize
);
414 size_t write(const void* buffer
, size_t bufferSize
, const struct timeval
& writeTimeout
) override
418 int res
= SSL_write(d_conn
.get(), (reinterpret_cast<const char *>(buffer
) + got
), static_cast<int>(bufferSize
- got
));
420 handleIORequest(res
, writeTimeout
);
423 got
+= static_cast<size_t>(res
);
426 while (got
< bufferSize
);
431 bool isUsable() const override
438 int res
= SSL_peek(d_conn
.get(), &buf
, sizeof(buf
));
443 convertIORequestToIOState(res
);
453 void close() override
456 SSL_shutdown(d_conn
.get());
460 std::string
getServerNameIndication() const override
463 const char* value
= SSL_get_servername(d_conn
.get(), TLSEXT_NAMETYPE_host_name
);
465 return std::string(value
);
468 return std::string();
471 std::vector
<uint8_t> getNextProtocol() const override
473 std::vector
<uint8_t> result
;
478 const unsigned char* alpn
= nullptr;
479 unsigned int alpnLen
= 0;
481 #ifdef HAVE_SSL_GET0_NEXT_PROTO_NEGOTIATED
482 SSL_get0_next_proto_negotiated(d_conn
.get(), &alpn
, &alpnLen
);
483 #endif /* HAVE_SSL_GET0_NEXT_PROTO_NEGOTIATED */
484 #endif /* DISABLE_NPN */
485 #ifdef HAVE_SSL_GET0_ALPN_SELECTED
486 if (alpn
== nullptr) {
487 SSL_get0_alpn_selected(d_conn
.get(), &alpn
, &alpnLen
);
489 #endif /* HAVE_SSL_GET0_ALPN_SELECTED */
490 if (alpn
!= nullptr && alpnLen
> 0) {
491 result
.insert(result
.end(), alpn
, alpn
+ alpnLen
);
496 LibsslTLSVersion
getTLSVersion() const override
498 auto proto
= SSL_version(d_conn
.get());
501 return LibsslTLSVersion::TLS10
;
503 return LibsslTLSVersion::TLS11
;
505 return LibsslTLSVersion::TLS12
;
506 #ifdef TLS1_3_VERSION
508 return LibsslTLSVersion::TLS13
;
509 #endif /* TLS1_3_VERSION */
511 return LibsslTLSVersion::Unknown
;
515 bool hasSessionBeenResumed() const override
518 return SSL_session_reused(d_conn
.get()) != 0;
523 std::vector
<std::unique_ptr
<TLSSession
>> getSessions() override
525 return std::move(d_tlsSessions
);
528 void setSession(std::unique_ptr
<TLSSession
>& session
) override
530 auto sess
= dynamic_cast<OpenSSLSession
*>(session
.get());
532 throw std::runtime_error("Unable to convert OpenSSL session");
535 auto native
= sess
->getNative();
536 auto ret
= SSL_set_session(d_conn
.get(), native
.get());
538 throw std::runtime_error("Error setting up session: " + libssl_get_error_string());
543 void addNewTicket(SSL_SESSION
* session
)
545 d_tlsSessions
.push_back(std::make_unique
<OpenSSLSession
>(std::unique_ptr
<SSL_SESSION
, void (*)(SSL_SESSION
*)>(session
, SSL_SESSION_free
)));
553 static void generateConnectionIndexIfNeeded()
555 auto init
= s_initTLSConnIndex
.lock();
560 /* not initialized yet */
561 s_tlsConnIndex
= SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
562 if (s_tlsConnIndex
== -1) {
563 throw std::runtime_error("Error getting an index for TLS connection data");
569 static int getConnectionIndex()
571 return s_tlsConnIndex
;
575 static LockGuarded
<bool> s_initTLSConnIndex
;
576 static int s_tlsConnIndex
;
577 std::vector
<std::unique_ptr
<TLSSession
>> d_tlsSessions
;
579 std::shared_ptr
<OpenSSLFrontendContext
> d_feContext
;
581 std::shared_ptr
<SSL_CTX
> d_tlsCtx
;
582 std::unique_ptr
<SSL
, void(*)(SSL
*)> d_conn
;
583 std::string d_hostname
;
584 struct timeval d_timeout
;
585 bool d_connected
{false};
589 LockGuarded
<bool> OpenSSLTLSConnection::s_initTLSConnIndex
{false};
590 int OpenSSLTLSConnection::s_tlsConnIndex
{-1};
592 class OpenSSLTLSIOCtx
: public TLSCtx
595 /* server side context */
596 OpenSSLTLSIOCtx(TLSFrontend
& fe
): d_feContext(std::make_shared
<OpenSSLFrontendContext
>(fe
.d_addr
, fe
.d_tlsConfig
))
598 OpenSSLTLSConnection::generateConnectionIndexIfNeeded();
600 d_ticketsKeyRotationDelay
= fe
.d_tlsConfig
.d_ticketsKeyRotationDelay
;
602 if (fe
.d_tlsConfig
.d_enableTickets
&& fe
.d_tlsConfig
.d_numberOfTicketsKeys
> 0) {
603 /* use our own ticket keys handler so we can rotate them */
604 #if OPENSSL_VERSION_MAJOR >= 3
605 SSL_CTX_set_tlsext_ticket_key_evp_cb(d_feContext
->d_tlsCtx
.get(), &OpenSSLTLSIOCtx::ticketKeyCb
);
607 SSL_CTX_set_tlsext_ticket_key_cb(d_feContext
->d_tlsCtx
.get(), &OpenSSLTLSIOCtx::ticketKeyCb
);
609 libssl_set_ticket_key_callback_data(d_feContext
->d_tlsCtx
.get(), d_feContext
.get());
612 #ifndef DISABLE_OCSP_STAPLING
613 if (!d_feContext
->d_ocspResponses
.empty()) {
614 SSL_CTX_set_tlsext_status_cb(d_feContext
->d_tlsCtx
.get(), &OpenSSLTLSIOCtx::ocspStaplingCb
);
615 SSL_CTX_set_tlsext_status_arg(d_feContext
->d_tlsCtx
.get(), &d_feContext
->d_ocspResponses
);
617 #endif /* DISABLE_OCSP_STAPLING */
619 if (fe
.d_tlsConfig
.d_readAhead
) {
620 SSL_CTX_set_read_ahead(d_feContext
->d_tlsCtx
.get(), 1);
623 libssl_set_error_counters_callback(d_feContext
->d_tlsCtx
, &fe
.d_tlsCounters
);
625 if (!fe
.d_tlsConfig
.d_keyLogFile
.empty()) {
626 d_feContext
->d_keyLogFile
= libssl_set_key_log_file(d_feContext
->d_tlsCtx
, fe
.d_tlsConfig
.d_keyLogFile
);
630 if (fe
.d_tlsConfig
.d_ticketKeyFile
.empty()) {
631 handleTicketsKeyRotation(time(nullptr));
634 OpenSSLTLSIOCtx::loadTicketsKeys(fe
.d_tlsConfig
.d_ticketKeyFile
);
637 catch (const std::exception
& e
) {
642 /* client side context */
643 OpenSSLTLSIOCtx(const TLSContextParameters
& params
)
648 SSL_OP_NO_COMPRESSION
|
649 SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION
|
650 SSL_OP_SINGLE_DH_USE
|
651 SSL_OP_SINGLE_ECDH_USE
|
652 #ifdef SSL_OP_IGNORE_UNEXPECTED_EOF
653 SSL_OP_IGNORE_UNEXPECTED_EOF
|
655 SSL_OP_CIPHER_SERVER_PREFERENCE
;
656 if (!params
.d_enableRenegotiation
) {
657 #ifdef SSL_OP_NO_RENEGOTIATION
658 sslOptions
|= SSL_OP_NO_RENEGOTIATION
;
659 #elif defined(SSL_OP_NO_CLIENT_RENEGOTIATION)
660 sslOptions
|= SSL_OP_NO_CLIENT_RENEGOTIATION
;
665 #ifdef SSL_OP_ENABLE_KTLS
666 sslOptions
|= SSL_OP_ENABLE_KTLS
;
668 #endif /* SSL_OP_ENABLE_KTLS */
671 registerOpenSSLUser();
673 OpenSSLTLSConnection::generateConnectionIndexIfNeeded();
675 #ifdef HAVE_TLS_CLIENT_METHOD
676 d_tlsCtx
= std::shared_ptr
<SSL_CTX
>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free
);
678 d_tlsCtx
= std::shared_ptr
<SSL_CTX
>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free
);
681 ERR_print_errors_fp(stderr
);
682 throw std::runtime_error("Error creating TLS context");
685 SSL_CTX_set_options(d_tlsCtx
.get(), sslOptions
);
686 #if defined(SSL_CTX_set_ecdh_auto)
687 SSL_CTX_set_ecdh_auto(d_tlsCtx
.get(), 1);
690 if (!params
.d_ciphers
.empty()) {
691 if (SSL_CTX_set_cipher_list(d_tlsCtx
.get(), params
.d_ciphers
.c_str()) != 1) {
692 ERR_print_errors_fp(stderr
);
693 throw std::runtime_error("Error setting the cipher list to '" + params
.d_ciphers
+ "' for the TLS context");
696 #ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
697 if (!params
.d_ciphers13
.empty()) {
698 if (SSL_CTX_set_ciphersuites(d_tlsCtx
.get(), params
.d_ciphers13
.c_str()) != 1) {
699 ERR_print_errors_fp(stderr
);
700 throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + params
.d_ciphers13
+ "' for the TLS context");
703 #endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
705 if (params
.d_validateCertificates
) {
706 if (params
.d_caStore
.empty()) {
707 if (SSL_CTX_set_default_verify_paths(d_tlsCtx
.get()) != 1) {
708 throw std::runtime_error("Error adding the system's default trusted CAs");
711 if (SSL_CTX_load_verify_locations(d_tlsCtx
.get(), params
.d_caStore
.c_str(), nullptr) != 1) {
712 throw std::runtime_error("Error adding the trusted CAs file " + params
.d_caStore
);
716 SSL_CTX_set_verify(d_tlsCtx
.get(), SSL_VERIFY_PEER
, nullptr);
717 #if (OPENSSL_VERSION_NUMBER < 0x10002000L)
718 warnlog("TLS hostname validation requested but not supported for OpenSSL < 1.0.2");
722 /* we need to set SSL_SESS_CACHE_CLIENT for the "new ticket" callback (below) to be called,
723 but we don't want OpenSSL to cache the session itself so we set SSL_SESS_CACHE_NO_INTERNAL_STORE as well */
724 SSL_CTX_set_session_cache_mode(d_tlsCtx
.get(), SSL_SESS_CACHE_CLIENT
| SSL_SESS_CACHE_NO_INTERNAL_STORE
);
725 SSL_CTX_sess_set_new_cb(d_tlsCtx
.get(), &OpenSSLTLSIOCtx::newTicketFromServerCb
);
727 #ifdef SSL_MODE_RELEASE_BUFFERS
728 if (params
.d_releaseBuffers
) {
729 SSL_CTX_set_mode(d_tlsCtx
.get(), SSL_MODE_RELEASE_BUFFERS
);
734 ~OpenSSLTLSIOCtx() override
737 unregisterOpenSSLUser();
740 #if OPENSSL_VERSION_MAJOR >= 3
741 static int ticketKeyCb(SSL
* s
, unsigned char keyName
[TLS_TICKETS_KEY_NAME_SIZE
], unsigned char* iv
, EVP_CIPHER_CTX
* ectx
, EVP_MAC_CTX
* hctx
, int enc
)
743 static int ticketKeyCb(SSL
* s
, unsigned char keyName
[TLS_TICKETS_KEY_NAME_SIZE
], unsigned char* iv
, EVP_CIPHER_CTX
* ectx
, HMAC_CTX
* hctx
, int enc
)
746 auto* ctx
= reinterpret_cast<OpenSSLFrontendContext
*>(libssl_get_ticket_key_callback_data(s
));
747 if (ctx
== nullptr) {
751 int ret
= libssl_ticket_key_callback(s
, ctx
->d_ticketKeys
, keyName
, iv
, ectx
, hctx
, enc
);
753 if (ret
== 0 || ret
== 2) {
754 auto* conn
= reinterpret_cast<OpenSSLTLSConnection
*>(SSL_get_ex_data(s
, OpenSSLTLSConnection::getConnectionIndex()));
755 if (conn
!= nullptr) {
757 conn
->setUnknownTicketKey();
760 conn
->setResumedFromInactiveTicketKey();
769 #ifndef DISABLE_OCSP_STAPLING
770 static int ocspStaplingCb(SSL
* ssl
, void* arg
)
772 if (ssl
== nullptr || arg
== nullptr) {
773 return SSL_TLSEXT_ERR_NOACK
;
775 const auto ocspMap
= reinterpret_cast<std::map
<int, std::string
>*>(arg
);
776 return libssl_ocsp_stapling_callback(ssl
, *ocspMap
);
778 #endif /* DISABLE_OCSP_STAPLING */
780 static int newTicketFromServerCb(SSL
* ssl
, SSL_SESSION
* session
)
782 OpenSSLTLSConnection
* conn
= reinterpret_cast<OpenSSLTLSConnection
*>(SSL_get_ex_data(ssl
, OpenSSLTLSConnection::getConnectionIndex()));
783 if (session
== nullptr || conn
== nullptr) {
787 conn
->addNewTicket(session
);
791 std::unique_ptr
<TLSConnection
> getConnection(int socket
, const struct timeval
& timeout
, time_t now
) override
793 handleTicketsKeyRotation(now
);
795 return std::make_unique
<OpenSSLTLSConnection
>(socket
, timeout
, d_feContext
);
798 std::unique_ptr
<TLSConnection
> getClientConnection(const std::string
& host
, bool hostIsAddr
, int socket
, const struct timeval
& timeout
) override
800 auto conn
= std::make_unique
<OpenSSLTLSConnection
>(host
, hostIsAddr
, socket
, timeout
, d_tlsCtx
);
807 void rotateTicketsKey(time_t now
) override
809 d_feContext
->d_ticketKeys
.rotateTicketsKey(now
);
811 if (d_ticketsKeyRotationDelay
> 0) {
812 d_ticketsKeyNextRotation
= now
+ d_ticketsKeyRotationDelay
;
816 void loadTicketsKeys(const std::string
& keyFile
) final
818 d_feContext
->d_ticketKeys
.loadTicketsKeys(keyFile
);
820 if (d_ticketsKeyRotationDelay
> 0) {
821 d_ticketsKeyNextRotation
= time(nullptr) + d_ticketsKeyRotationDelay
;
825 size_t getTicketsKeysCount() override
827 return d_feContext
->d_ticketKeys
.getKeysCount();
830 std::string
getName() const override
835 bool setALPNProtos(const std::vector
<std::vector
<uint8_t>>& protos
) override
837 if (d_feContext
&& d_feContext
->d_tlsCtx
) {
838 d_alpnProtos
= protos
;
839 libssl_set_alpn_select_callback(d_feContext
->d_tlsCtx
.get(), alpnServerSelectCallback
, this);
843 return libssl_set_alpn_protos(d_tlsCtx
.get(), protos
);
849 bool setNextProtocolSelectCallback(bool(*cb
)(unsigned char** out
, unsigned char* outlen
, const unsigned char* in
, unsigned int inlen
)) override
851 d_nextProtocolSelectCallback
= cb
;
852 libssl_set_npn_select_callback(d_tlsCtx
.get(), npnSelectCallback
, this);
855 #endif /* DISABLE_NPN */
858 /* called in a client context, if the client advertised more than one ALPN values and the server returned more than one as well, to select the one to use. */
860 static int npnSelectCallback(SSL
* /* s */, unsigned char** out
, unsigned char* outlen
, const unsigned char* in
, unsigned int inlen
, void* arg
)
863 return SSL_TLSEXT_ERR_ALERT_WARNING
;
865 OpenSSLTLSIOCtx
* obj
= reinterpret_cast<OpenSSLTLSIOCtx
*>(arg
);
866 if (obj
->d_nextProtocolSelectCallback
) {
867 return (*obj
->d_nextProtocolSelectCallback
)(out
, outlen
, in
, inlen
) ? SSL_TLSEXT_ERR_OK
: SSL_TLSEXT_ERR_ALERT_WARNING
;
870 return SSL_TLSEXT_ERR_OK
;
874 static int alpnServerSelectCallback(SSL
*, const unsigned char** out
, unsigned char* outlen
, const unsigned char* in
, unsigned int inlen
, void* arg
)
877 return SSL_TLSEXT_ERR_ALERT_WARNING
;
879 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): OpenSSL's API
880 OpenSSLTLSIOCtx
* obj
= reinterpret_cast<OpenSSLTLSIOCtx
*>(arg
);
882 const pdns::views::UnsignedCharView
inView(in
, inlen
);
883 // Server preference algorithm as per RFC 7301 section 3.2
884 for (const auto& tentative
: obj
->d_alpnProtos
) {
886 while (pos
< inView
.size()) {
887 size_t protoLen
= inView
.at(pos
);
889 if (protoLen
> (inlen
- pos
)) {
890 /* something is very wrong */
891 return SSL_TLSEXT_ERR_ALERT_WARNING
;
894 if (tentative
.size() == protoLen
&& memcmp(&inView
.at(pos
), tentative
.data(), tentative
.size()) == 0) {
895 *out
= &inView
.at(pos
);
897 return SSL_TLSEXT_ERR_OK
;
903 return SSL_TLSEXT_ERR_NOACK
;
906 std::vector
<std::vector
<uint8_t>> d_alpnProtos
; // store the supported ALPN protocols, so that the server can select based on what the client sent
907 std::shared_ptr
<OpenSSLFrontendContext
> d_feContext
{nullptr};
908 std::shared_ptr
<SSL_CTX
> d_tlsCtx
{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
909 bool (*d_nextProtocolSelectCallback
)(unsigned char** out
, unsigned char* outlen
, const unsigned char* in
, unsigned int inlen
){nullptr};
913 #endif /* HAVE_LIBSSL */
916 #include <gnutls/gnutls.h>
917 #include <gnutls/x509.h>
919 static void safe_memory_lock(void* data
, size_t size
)
921 #ifdef HAVE_LIBSODIUM
922 sodium_mlock(data
, size
);
926 static void safe_memory_release(void* data
, size_t size
)
928 #ifdef HAVE_LIBSODIUM
929 sodium_munlock(data
, size
);
930 #elif defined(HAVE_EXPLICIT_BZERO)
931 explicit_bzero(data
, size
);
932 #elif defined(HAVE_EXPLICIT_MEMSET)
933 explicit_memset(data
, 0, size
);
934 #elif defined(HAVE_GNUTLS_MEMSET)
935 gnutls_memset(data
, 0, size
);
937 /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
938 volatile unsigned int volatile_zero_idx
= 0;
939 volatile unsigned char *p
= reinterpret_cast<volatile unsigned char *>(data
);
945 memset(data
, 0, size
);
946 } while (p
[volatile_zero_idx
] != 0);
950 class GnuTLSTicketsKey
955 if (gnutls_session_ticket_key_generate(&d_key
) != GNUTLS_E_SUCCESS
) {
956 throw std::runtime_error("Error generating tickets key for TLS context");
959 safe_memory_lock(d_key
.data
, d_key
.size
);
962 GnuTLSTicketsKey(const std::string
& keyFile
)
964 /* to be sure we are loading the correct amount of data, which
965 may change between versions, let's generate a correct key first */
966 if (gnutls_session_ticket_key_generate(&d_key
) != GNUTLS_E_SUCCESS
) {
967 throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
970 safe_memory_lock(d_key
.data
, d_key
.size
);
973 ifstream
file(keyFile
);
974 file
.read(reinterpret_cast<char*>(d_key
.data
), d_key
.size
);
978 throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile
);
983 catch (const std::exception
& e
) {
984 safe_memory_release(d_key
.data
, d_key
.size
);
985 gnutls_free(d_key
.data
);
986 d_key
.data
= nullptr;
993 if (d_key
.data
!= nullptr && d_key
.size
> 0) {
994 safe_memory_release(d_key
.data
, d_key
.size
);
996 gnutls_free(d_key
.data
);
997 d_key
.data
= nullptr;
999 const gnutls_datum_t
& getKey() const
1005 gnutls_datum_t d_key
{nullptr, 0};
1008 class GnuTLSSession
: public TLSSession
1011 GnuTLSSession(gnutls_datum_t
& sess
): d_sess(sess
)
1013 sess
.data
= nullptr;
1017 ~GnuTLSSession() override
1019 if (d_sess
.data
!= nullptr && d_sess
.size
> 0) {
1020 safe_memory_release(d_sess
.data
, d_sess
.size
);
1022 gnutls_free(d_sess
.data
);
1023 d_sess
.data
= nullptr;
1026 const gnutls_datum_t
& getNative()
1032 gnutls_datum_t d_sess
{nullptr, 0};
1035 class GnuTLSConnection
: public TLSConnection
1038 /* server side connection */
1039 GnuTLSConnection(int socket
, const struct timeval
& timeout
, std::shared_ptr
<gnutls_certificate_credentials_st
>& creds
, const gnutls_priority_t priorityCache
, std::shared_ptr
<GnuTLSTicketsKey
>& ticketsKey
, bool enableTickets
): d_creds(creds
), d_ticketsKey(ticketsKey
), d_conn(std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)>(nullptr, gnutls_deinit
))
1041 unsigned int sslOptions
= GNUTLS_SERVER
| GNUTLS_NONBLOCK
;
1042 #ifdef GNUTLS_NO_SIGNAL
1043 sslOptions
|= GNUTLS_NO_SIGNAL
;
1048 gnutls_session_t conn
;
1049 if (gnutls_init(&conn
, sslOptions
) != GNUTLS_E_SUCCESS
) {
1050 throw std::runtime_error("Error creating TLS connection");
1053 d_conn
= std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)>(conn
, gnutls_deinit
);
1056 if (gnutls_credentials_set(d_conn
.get(), GNUTLS_CRD_CERTIFICATE
, d_creds
.get()) != GNUTLS_E_SUCCESS
) {
1057 throw std::runtime_error("Error setting certificate and key to TLS connection");
1060 if (gnutls_priority_set(d_conn
.get(), priorityCache
) != GNUTLS_E_SUCCESS
) {
1061 throw std::runtime_error("Error setting ciphers to TLS connection");
1064 if (enableTickets
&& d_ticketsKey
) {
1065 const gnutls_datum_t
& key
= d_ticketsKey
->getKey();
1066 if (gnutls_session_ticket_enable_server(d_conn
.get(), &key
) != GNUTLS_E_SUCCESS
) {
1067 throw std::runtime_error("Error setting the tickets key to TLS connection");
1071 gnutls_transport_set_int(d_conn
.get(), d_socket
);
1073 /* timeouts are in milliseconds */
1074 gnutls_handshake_set_timeout(d_conn
.get(), timeout
.tv_sec
* 1000 + timeout
.tv_usec
/ 1000);
1075 gnutls_record_set_timeout(d_conn
.get(), timeout
.tv_sec
* 1000 + timeout
.tv_usec
/ 1000);
1078 /* client-side connection */
1079 GnuTLSConnection(const std::string
& host
, int socket
, const struct timeval
& timeout
, std::shared_ptr
<gnutls_certificate_credentials_st
>& creds
, const gnutls_priority_t priorityCache
, bool validateCerts
): d_creds(creds
), d_conn(std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)>(nullptr, gnutls_deinit
)), d_host(host
), d_client(true)
1081 unsigned int sslOptions
= GNUTLS_CLIENT
| GNUTLS_NONBLOCK
;
1082 #ifdef GNUTLS_NO_SIGNAL
1083 sslOptions
|= GNUTLS_NO_SIGNAL
;
1088 gnutls_session_t conn
;
1089 if (gnutls_init(&conn
, sslOptions
) != GNUTLS_E_SUCCESS
) {
1090 throw std::runtime_error("Error creating TLS connection");
1093 d_conn
= std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)>(conn
, gnutls_deinit
);
1096 int rc
= gnutls_credentials_set(d_conn
.get(), GNUTLS_CRD_CERTIFICATE
, d_creds
.get());
1097 if (rc
!= GNUTLS_E_SUCCESS
) {
1098 throw std::runtime_error("Error setting certificate and key to TLS connection: " + std::string(gnutls_strerror(rc
)));
1101 rc
= gnutls_priority_set(d_conn
.get(), priorityCache
);
1102 if (rc
!= GNUTLS_E_SUCCESS
) {
1103 throw std::runtime_error("Error setting ciphers to TLS connection: " + std::string(gnutls_strerror(rc
)));
1106 gnutls_transport_set_int(d_conn
.get(), d_socket
);
1108 /* timeouts are in milliseconds */
1109 gnutls_handshake_set_timeout(d_conn
.get(), timeout
.tv_sec
* 1000 + timeout
.tv_usec
/ 1000);
1110 gnutls_record_set_timeout(d_conn
.get(), timeout
.tv_sec
* 1000 + timeout
.tv_usec
/ 1000);
1112 #ifdef HAVE_GNUTLS_SESSION_SET_VERIFY_CERT
1113 if (validateCerts
&& !d_host
.empty()) {
1114 gnutls_session_set_verify_cert(d_conn
.get(), d_host
.c_str(), GNUTLS_VERIFY_ALLOW_UNSORTED_CHAIN
);
1115 rc
= gnutls_server_name_set(d_conn
.get(), GNUTLS_NAME_DNS
, d_host
.c_str(), d_host
.size());
1116 if (rc
!= GNUTLS_E_SUCCESS
) {
1117 throw std::runtime_error("Error setting the SNI value to '" + d_host
+ "' on TLS connection: " + std::string(gnutls_strerror(rc
)));
1121 /* no hostname validation for you */
1124 /* allow access to our data in the callbacks */
1125 gnutls_session_set_ptr(d_conn
.get(), this);
1126 gnutls_handshake_set_hook_function(d_conn
.get(), GNUTLS_HANDSHAKE_NEW_SESSION_TICKET
, GNUTLS_HOOK_POST
, newTicketFromServerCb
);
1129 /* The callback prototype changed in 3.4.0. */
1130 #if GNUTLS_VERSION_NUMBER >= 0x030400
1131 static int newTicketFromServerCb(gnutls_session_t session
, unsigned int htype
, unsigned post
, unsigned int /* incoming */, const gnutls_datum_t
* /* msg */)
1133 static int newTicketFromServerCb(gnutls_session_t session
, unsigned int htype
, unsigned post
, unsigned int /* incoming */)
1134 #endif /* GNUTLS_VERSION_NUMBER >= 0x030400 */
1136 if (htype
!= GNUTLS_HANDSHAKE_NEW_SESSION_TICKET
|| post
!= GNUTLS_HOOK_POST
|| session
== nullptr) {
1140 GnuTLSConnection
* conn
= reinterpret_cast<GnuTLSConnection
*>(gnutls_session_get_ptr(session
));
1141 if (conn
== nullptr) {
1145 gnutls_datum_t sess
{nullptr, 0};
1146 auto ret
= gnutls_session_get_data2(session
, &sess
);
1147 /* GnuTLS returns a 'fake' ticket of 4 bytes set to zero when there is no ticket available */
1148 if (ret
!= GNUTLS_E_SUCCESS
|| sess
.size
<= 4) {
1149 throw std::runtime_error("Error getting GnuTLSSession: " + std::string(gnutls_strerror(ret
)));
1151 conn
->d_tlsSessions
.push_back(std::make_unique
<GnuTLSSession
>(sess
));
1155 IOState
tryConnect(bool fastOpen
, [[maybe_unused
]] const ComboAddress
& remote
) override
1160 #ifdef HAVE_GNUTLS_TRANSPORT_SET_FASTOPEN
1161 gnutls_transport_set_fastopen(d_conn
.get(), d_socket
, const_cast<struct sockaddr
*>(reinterpret_cast<const struct sockaddr
*>(&remote
)), remote
.getSocklen(), 0);
1166 ret
= gnutls_handshake(d_conn
.get());
1167 if (ret
== GNUTLS_E_SUCCESS
) {
1168 d_handshakeDone
= true;
1169 return IOState::Done
;
1171 else if (ret
== GNUTLS_E_AGAIN
) {
1172 int direction
= gnutls_record_get_direction(d_conn
.get());
1173 return direction
== 0 ? IOState::NeedRead
: IOState::NeedWrite
;
1175 else if (gnutls_error_is_fatal(ret
) || ret
== GNUTLS_E_WARNING_ALERT_RECEIVED
) {
1176 throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret
)));
1178 } while (ret
== GNUTLS_E_INTERRUPTED
);
1180 throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret
)));
1183 void connect(bool fastOpen
, const ComboAddress
& remote
, const struct timeval
& timeout
) override
1185 struct timeval start
= {0, 0};
1186 struct timeval remainingTime
= timeout
;
1187 if (timeout
.tv_sec
!= 0 || timeout
.tv_usec
!= 0) {
1188 gettimeofday(&start
, nullptr);
1193 state
= tryConnect(fastOpen
, remote
);
1194 if (state
== IOState::Done
) {
1197 else if (state
== IOState::NeedRead
) {
1198 int result
= waitForData(d_socket
, remainingTime
.tv_sec
, remainingTime
.tv_usec
);
1200 throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result
));
1203 else if (state
== IOState::NeedWrite
) {
1204 int result
= waitForRWData(d_socket
, false, remainingTime
.tv_sec
, remainingTime
.tv_usec
);
1206 throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result
));
1210 if (timeout
.tv_sec
!= 0 || timeout
.tv_usec
!= 0) {
1212 gettimeofday(&now
, nullptr);
1213 struct timeval elapsed
= now
- start
;
1214 if (now
< start
|| remainingTime
< elapsed
) {
1215 throw runtime_error("Timeout while establishing TLS connection");
1218 remainingTime
= remainingTime
- elapsed
;
1221 while (state
!= IOState::Done
);
1224 void doHandshake() override
1228 ret
= gnutls_handshake(d_conn
.get());
1229 if (gnutls_error_is_fatal(ret
) || ret
== GNUTLS_E_WARNING_ALERT_RECEIVED
) {
1231 throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret
)));
1234 throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret
)));
1238 while (ret
!= GNUTLS_E_SUCCESS
&& ret
== GNUTLS_E_INTERRUPTED
);
1240 d_handshakeDone
= true;
1243 IOState
tryHandshake() override
1248 ret
= gnutls_handshake(d_conn
.get());
1249 if (ret
== GNUTLS_E_SUCCESS
) {
1250 d_handshakeDone
= true;
1251 return IOState::Done
;
1253 else if (ret
== GNUTLS_E_AGAIN
) {
1254 int direction
= gnutls_record_get_direction(d_conn
.get());
1255 return direction
== 0 ? IOState::NeedRead
: IOState::NeedWrite
;
1257 else if (gnutls_error_is_fatal(ret
) || ret
== GNUTLS_E_WARNING_ALERT_RECEIVED
) {
1260 #ifdef HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS
1261 if (ret
== GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR
) {
1263 if (gnutls_certificate_verification_status_print(gnutls_session_get_verify_cert_status(d_conn
.get()), gnutls_certificate_type_get(d_conn
.get()), &out
, 0) == 0) {
1264 error
= " (" + std::string(reinterpret_cast<const char*>(out
.data
)) + ")";
1265 gnutls_free(out
.data
);
1268 #endif /* HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS */
1269 throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret
)) + error
);
1272 throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret
)));
1275 } while (ret
== GNUTLS_E_INTERRUPTED
);
1278 throw std::runtime_error("Error establishinging a new connection: " + std::string(gnutls_strerror(ret
)));
1281 throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret
)));
1285 IOState
tryWrite(const PacketBuffer
& buffer
, size_t& pos
, size_t toWrite
) override
1287 if (!d_handshakeDone
) {
1288 /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
1289 we need to keep calling gnutls_handshake() until the handshake has been finished. */
1290 auto state
= tryHandshake();
1291 if (state
!= IOState::Done
) {
1297 ssize_t res
= gnutls_record_send(d_conn
.get(), reinterpret_cast<const char *>(&buffer
.at(pos
)), toWrite
- pos
);
1299 throw std::runtime_error("Error writing to TLS connection");
1302 pos
+= static_cast<size_t>(res
);
1305 if (gnutls_error_is_fatal(res
)) {
1306 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res
)));
1308 else if (res
== GNUTLS_E_AGAIN
) {
1309 return IOState::NeedWrite
;
1311 vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res
));
1314 while (pos
< toWrite
);
1315 return IOState::Done
;
1318 IOState
tryRead(PacketBuffer
& buffer
, size_t& pos
, size_t toRead
, bool allowIncomplete
) override
1320 if (!d_handshakeDone
) {
1321 /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
1322 we need to keep calling gnutls_handshake() until the handshake has been finished. */
1323 auto state
= tryHandshake();
1324 if (state
!= IOState::Done
) {
1330 ssize_t res
= gnutls_record_recv(d_conn
.get(), reinterpret_cast<char *>(&buffer
.at(pos
)), toRead
- pos
);
1332 throw std::runtime_error("EOF while reading from TLS connection");
1335 pos
+= static_cast<size_t>(res
);
1336 if (allowIncomplete
) {
1341 if (gnutls_error_is_fatal(res
)) {
1342 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res
)));
1344 else if (res
== GNUTLS_E_AGAIN
) {
1345 return IOState::NeedRead
;
1347 vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res
));
1350 while (pos
< toRead
);
1351 return IOState::Done
;
1354 size_t read(void* buffer
, size_t bufferSize
, const struct timeval
& readTimeout
, const struct timeval
& totalTimeout
, bool allowIncomplete
) override
1357 struct timeval start
{0,0};
1358 struct timeval remainingTime
= totalTimeout
;
1359 if (totalTimeout
.tv_sec
!= 0 || totalTimeout
.tv_usec
!= 0) {
1360 gettimeofday(&start
, nullptr);
1364 ssize_t res
= gnutls_record_recv(d_conn
.get(), (reinterpret_cast<char *>(buffer
) + got
), bufferSize
- got
);
1366 throw std::runtime_error("EOF while reading from TLS connection");
1369 got
+= static_cast<size_t>(res
);
1370 if (allowIncomplete
) {
1375 if (gnutls_error_is_fatal(res
)) {
1376 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res
)));
1378 else if (res
== GNUTLS_E_AGAIN
) {
1379 int result
= waitForData(d_socket
, readTimeout
.tv_sec
, readTimeout
.tv_usec
);
1381 throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result
));
1385 vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res
));
1389 if (totalTimeout
.tv_sec
!= 0 || totalTimeout
.tv_usec
!= 0) {
1391 gettimeofday(&now
, nullptr);
1392 struct timeval elapsed
= now
- start
;
1393 if (now
< start
|| remainingTime
< elapsed
) {
1394 throw runtime_error("Timeout while reading data");
1397 remainingTime
= remainingTime
- elapsed
;
1400 while (got
< bufferSize
);
1405 size_t write(const void* buffer
, size_t bufferSize
, const struct timeval
& writeTimeout
) override
1410 ssize_t res
= gnutls_record_send(d_conn
.get(), (reinterpret_cast<const char *>(buffer
) + got
), bufferSize
- got
);
1412 throw std::runtime_error("Error writing to TLS connection");
1415 got
+= static_cast<size_t>(res
);
1418 if (gnutls_error_is_fatal(res
)) {
1419 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res
)));
1421 else if (res
== GNUTLS_E_AGAIN
) {
1422 int result
= waitForRWData(d_socket
, false, writeTimeout
.tv_sec
, writeTimeout
.tv_usec
);
1424 throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result
));
1428 vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res
));
1432 while (got
< bufferSize
);
1437 bool isUsable() const override
1443 /* as far as I can tell we can't peek so we cannot do better */
1444 return isTCPSocketUsable(d_socket
);
1447 std::string
getServerNameIndication() const override
1451 size_t name_len
= 256;
1453 sni
.resize(name_len
);
1455 int res
= gnutls_server_name_get(d_conn
.get(), const_cast<char*>(sni
.c_str()), &name_len
, &type
, 0);
1456 if (res
== GNUTLS_E_SUCCESS
) {
1457 sni
.resize(name_len
);
1461 return std::string();
1464 std::vector
<uint8_t> getNextProtocol() const override
1466 std::vector
<uint8_t> result
;
1470 gnutls_datum_t next
;
1471 if (gnutls_alpn_get_selected_protocol(d_conn
.get(), &next
) != GNUTLS_E_SUCCESS
) {
1474 result
.insert(result
.end(), next
.data
, next
.data
+ next
.size
);
1478 LibsslTLSVersion
getTLSVersion() const override
1480 auto proto
= gnutls_protocol_get_version(d_conn
.get());
1483 return LibsslTLSVersion::TLS10
;
1485 return LibsslTLSVersion::TLS11
;
1487 return LibsslTLSVersion::TLS12
;
1488 #if GNUTLS_VERSION_NUMBER >= 0x030603
1490 return LibsslTLSVersion::TLS13
;
1491 #endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */
1493 return LibsslTLSVersion::Unknown
;
1497 bool hasSessionBeenResumed() const override
1500 return gnutls_session_is_resumed(d_conn
.get()) != 0;
1505 std::vector
<std::unique_ptr
<TLSSession
>> getSessions() override
1507 return std::move(d_tlsSessions
);
1510 void setSession(std::unique_ptr
<TLSSession
>& session
) override
1512 auto sess
= dynamic_cast<GnuTLSSession
*>(session
.get());
1514 throw std::runtime_error("Unable to convert GnuTLS session");
1517 auto native
= sess
->getNative();
1518 auto ret
= gnutls_session_set_data(d_conn
.get(), native
.data
, native
.size
);
1519 if (ret
!= GNUTLS_E_SUCCESS
) {
1520 throw std::runtime_error("Error setting up GnuTLS session: " + std::string(gnutls_strerror(ret
)));
1525 void close() override
1528 gnutls_bye(d_conn
.get(), GNUTLS_SHUT_RDWR
);
1532 bool setALPNProtos(const std::vector
<std::vector
<uint8_t>>& protos
)
1534 std::vector
<gnutls_datum_t
> values
;
1535 values
.reserve(protos
.size());
1536 for (const auto& proto
: protos
) {
1537 gnutls_datum_t value
;
1538 value
.data
= const_cast<uint8_t*>(proto
.data());
1539 value
.size
= proto
.size();
1540 values
.push_back(value
);
1542 unsigned int flags
= 0;
1543 #if GNUTLS_VERSION_NUMBER >= 0x030500
1544 flags
|= GNUTLS_ALPN_MANDATORY
;
1545 #elif defined(GNUTLS_ALPN_MAND)
1546 flags
|= GNUTLS_ALPN_MAND
;
1548 return gnutls_alpn_set_protocols(d_conn
.get(), values
.data(), values
.size(), flags
);
1551 std::vector
<int> getAsyncFDs() override
1557 std::shared_ptr
<gnutls_certificate_credentials_st
> d_creds
;
1558 std::shared_ptr
<GnuTLSTicketsKey
> d_ticketsKey
;
1559 std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)> d_conn
;
1560 std::vector
<std::unique_ptr
<TLSSession
>> d_tlsSessions
;
1562 bool d_client
{false};
1563 bool d_handshakeDone
{false};
1566 class GnuTLSIOCtx
: public TLSCtx
1569 /* server side context */
1570 GnuTLSIOCtx(TLSFrontend
& fe
): d_enableTickets(fe
.d_tlsConfig
.d_enableTickets
)
1573 d_ticketsKeyRotationDelay
= fe
.d_tlsConfig
.d_ticketsKeyRotationDelay
;
1575 gnutls_certificate_credentials_t creds
;
1576 rc
= gnutls_certificate_allocate_credentials(&creds
);
1577 if (rc
!= GNUTLS_E_SUCCESS
) {
1578 throw std::runtime_error("Error allocating credentials for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
1581 d_creds
= std::shared_ptr
<gnutls_certificate_credentials_st
>(creds
, gnutls_certificate_free_credentials
);
1584 for (const auto& pair
: fe
.d_tlsConfig
.d_certKeyPairs
) {
1585 rc
= gnutls_certificate_set_x509_key_file(d_creds
.get(), pair
.d_cert
.c_str(), pair
.d_key
->c_str(), GNUTLS_X509_FMT_PEM
);
1586 if (rc
!= GNUTLS_E_SUCCESS
) {
1587 throw std::runtime_error("Error loading certificate ('" + pair
.d_cert
+ "') and key ('" + pair
.d_key
.value() + "') for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
1591 #ifndef DISABLE_OCSP_STAPLING
1593 for (const auto& file
: fe
.d_tlsConfig
.d_ocspFiles
) {
1594 rc
= gnutls_certificate_set_ocsp_status_request_file(d_creds
.get(), file
.c_str(), count
);
1595 if (rc
!= GNUTLS_E_SUCCESS
) {
1596 warnlog("Error loading OCSP response from file '%s' for certificate ('%s') and key ('%s') for TLS context on %s: %s", file
, fe
.d_tlsConfig
.d_certKeyPairs
.at(count
).d_cert
, fe
.d_tlsConfig
.d_certKeyPairs
.at(count
).d_key
.value(), fe
.d_addr
.toStringWithPort(), gnutls_strerror(rc
));
1600 #endif /* DISABLE_OCSP_STAPLING */
1602 #if GNUTLS_VERSION_NUMBER >= 0x030600
1603 rc
= gnutls_certificate_set_known_dh_params(d_creds
.get(), GNUTLS_SEC_PARAM_HIGH
);
1604 if (rc
!= GNUTLS_E_SUCCESS
) {
1605 throw std::runtime_error("Error setting DH params for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
1609 rc
= gnutls_priority_init(&d_priorityCache
, fe
.d_tlsConfig
.d_ciphers
.empty() ? "NORMAL" : fe
.d_tlsConfig
.d_ciphers
.c_str(), nullptr);
1610 if (rc
!= GNUTLS_E_SUCCESS
) {
1611 throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe
.d_tlsConfig
.d_ciphers
+ "' (" + gnutls_strerror(rc
) + ") on " + fe
.d_addr
.toStringWithPort());
1615 if (fe
.d_tlsConfig
.d_ticketKeyFile
.empty()) {
1616 handleTicketsKeyRotation(time(nullptr));
1619 GnuTLSIOCtx::loadTicketsKeys(fe
.d_tlsConfig
.d_ticketKeyFile
);
1622 catch(const std::runtime_error
& e
) {
1623 throw std::runtime_error("Error generating tickets key for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + e
.what());
1627 /* client side context */
1628 GnuTLSIOCtx(const TLSContextParameters
& params
): d_contextParameters(std::make_unique
<TLSContextParameters
>(params
)), d_enableTickets(true), d_validateCerts(params
.d_validateCertificates
)
1632 gnutls_certificate_credentials_t creds
;
1633 rc
= gnutls_certificate_allocate_credentials(&creds
);
1634 if (rc
!= GNUTLS_E_SUCCESS
) {
1635 throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc
)));
1638 d_creds
= std::shared_ptr
<gnutls_certificate_credentials_st
>(creds
, gnutls_certificate_free_credentials
);
1641 if (params
.d_validateCertificates
) {
1642 if (params
.d_caStore
.empty()) {
1643 #if GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703
1644 /* see https://gitlab.com/gnutls/gnutls/-/issues/1277 */
1645 std::cerr
<<"Warning: GnuTLS 3.7.0 - 3.7.2 have a memory leak when validating server certificates in some configurations (PKCS11 support enabled, and a default PKCS11 trust store), please consider upgrading GnuTLS, using the OpenSSL provider for outgoing connections, or explicitly setting a CA store"<<std::endl
;
1646 #endif /* GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703 */
1647 rc
= gnutls_certificate_set_x509_system_trust(d_creds
.get());
1649 throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc
)));
1653 rc
= gnutls_certificate_set_x509_trust_file(d_creds
.get(), params
.d_caStore
.c_str(), GNUTLS_X509_FMT_PEM
);
1655 throw std::runtime_error("Error adding '" + params
.d_caStore
+ "' to the trusted CAs: " + std::string(gnutls_strerror(rc
)));
1660 rc
= gnutls_priority_init(&d_priorityCache
, params
.d_ciphers
.empty() ? "NORMAL" : params
.d_ciphers
.c_str(), nullptr);
1661 if (rc
!= GNUTLS_E_SUCCESS
) {
1662 throw std::runtime_error("Error setting up TLS cipher preferences to 'NORMAL' (" + std::string(gnutls_strerror(rc
)) + ")");
1666 ~GnuTLSIOCtx() override
1670 if (d_priorityCache
) {
1671 gnutls_priority_deinit(d_priorityCache
);
1675 std::unique_ptr
<TLSConnection
> getConnection(int socket
, const struct timeval
& timeout
, time_t now
) override
1677 handleTicketsKeyRotation(now
);
1679 std::shared_ptr
<GnuTLSTicketsKey
> ticketsKey
;
1681 ticketsKey
= *(d_ticketsKey
.read_lock());
1684 auto connection
= std::make_unique
<GnuTLSConnection
>(socket
, timeout
, d_creds
, d_priorityCache
, ticketsKey
, d_enableTickets
);
1685 if (!d_protos
.empty()) {
1686 connection
->setALPNProtos(d_protos
);
1691 static std::shared_ptr
<gnutls_certificate_credentials_st
> getPerThreadCredentials(bool validate
, const std::string
& caStore
)
1693 static thread_local
std::map
<std::pair
<bool, std::string
>, std::shared_ptr
<gnutls_certificate_credentials_st
>> t_credentials
;
1694 auto& entry
= t_credentials
[{validate
, caStore
}];
1696 gnutls_certificate_credentials_t creds
;
1697 int rc
= gnutls_certificate_allocate_credentials(&creds
);
1698 if (rc
!= GNUTLS_E_SUCCESS
) {
1699 throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc
)));
1702 entry
= std::shared_ptr
<gnutls_certificate_credentials_st
>(creds
, gnutls_certificate_free_credentials
);
1706 if (caStore
.empty()) {
1707 rc
= gnutls_certificate_set_x509_system_trust(entry
.get());
1709 throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc
)));
1713 rc
= gnutls_certificate_set_x509_trust_file(entry
.get(), caStore
.c_str(), GNUTLS_X509_FMT_PEM
);
1715 throw std::runtime_error("Error adding '" + caStore
+ "' to the trusted CAs: " + std::string(gnutls_strerror(rc
)));
1723 std::unique_ptr
<TLSConnection
> getClientConnection(const std::string
& host
, bool, int socket
, const struct timeval
& timeout
) override
1725 auto creds
= getPerThreadCredentials(d_contextParameters
->d_validateCertificates
, d_contextParameters
->d_caStore
);
1726 auto connection
= std::make_unique
<GnuTLSConnection
>(host
, socket
, timeout
, creds
, d_priorityCache
, d_validateCerts
);
1727 if (!d_protos
.empty()) {
1728 connection
->setALPNProtos(d_protos
);
1733 void rotateTicketsKey(time_t now
) override
1735 if (!d_enableTickets
) {
1739 auto newKey
= std::make_shared
<GnuTLSTicketsKey
>();
1742 *(d_ticketsKey
.write_lock()) = std::move(newKey
);
1745 if (d_ticketsKeyRotationDelay
> 0) {
1746 d_ticketsKeyNextRotation
= now
+ d_ticketsKeyRotationDelay
;
1750 void loadTicketsKeys(const std::string
& file
) final
1752 if (!d_enableTickets
) {
1756 auto newKey
= std::make_shared
<GnuTLSTicketsKey
>(file
);
1758 *(d_ticketsKey
.write_lock()) = std::move(newKey
);
1761 if (d_ticketsKeyRotationDelay
> 0) {
1762 d_ticketsKeyNextRotation
= time(nullptr) + d_ticketsKeyRotationDelay
;
1766 size_t getTicketsKeysCount() override
1768 return *(d_ticketsKey
.read_lock()) != nullptr ? 1 : 0;
1771 std::string
getName() const override
1776 bool setALPNProtos(const std::vector
<std::vector
<uint8_t>>& protos
) override
1778 #ifdef HAVE_GNUTLS_ALPN_SET_PROTOCOLS
1787 /* client context parameters */
1788 std::unique_ptr
<TLSContextParameters
> d_contextParameters
{nullptr};
1789 std::shared_ptr
<gnutls_certificate_credentials_st
> d_creds
;
1790 std::vector
<std::vector
<uint8_t>> d_protos
;
1791 gnutls_priority_t d_priorityCache
{nullptr};
1792 SharedLockGuarded
<std::shared_ptr
<GnuTLSTicketsKey
>> d_ticketsKey
{nullptr};
1793 bool d_enableTickets
{true};
1794 bool d_validateCerts
{true};
1797 #endif /* HAVE_GNUTLS */
1799 #endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
1801 bool setupDoTProtocolNegotiation(std::shared_ptr
<TLSCtx
>& ctx
)
1803 if (ctx
== nullptr) {
1806 /* we want to set the ALPN to dot (RFC7858), if only to mitigate the ALPACA attack */
1807 const std::vector
<std::vector
<uint8_t>> dotAlpns
= {{'d', 'o', 't'}};
1808 ctx
->setALPNProtos(dotAlpns
);
1812 bool setupDoHProtocolNegotiation(std::shared_ptr
<TLSCtx
>& ctx
)
1814 if (ctx
== nullptr) {
1817 /* This code is only called for incoming/server TLS contexts (not outgoing/client),
1818 and h2o sets it own ALPN values.
1819 We want to set the ALPN for DoH:
1820 - HTTP/1.1 so that the OpenSSL callback ALPN accepts it, letting us later return a static response
1823 const std::vector
<std::vector
<uint8_t>> dohAlpns
{{'h', '2'},{'h', 't', 't', 'p', '/', '1', '.', '1'}};
1824 ctx
->setALPNProtos(dohAlpns
);
1829 bool TLSFrontend::setupTLS()
1831 #if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
1832 std::shared_ptr
<TLSCtx
> newCtx
{nullptr};
1833 /* get the "best" available provider */
1834 #if defined(HAVE_GNUTLS)
1835 if (d_provider
== "gnutls") {
1836 newCtx
= std::make_shared
<GnuTLSIOCtx
>(*this);
1838 #endif /* HAVE_GNUTLS */
1839 #if defined(HAVE_LIBSSL)
1840 if (d_provider
== "openssl") {
1841 newCtx
= std::make_shared
<OpenSSLTLSIOCtx
>(*this);
1843 #endif /* HAVE_LIBSSL */
1846 #if defined(HAVE_LIBSSL)
1847 newCtx
= std::make_shared
<OpenSSLTLSIOCtx
>(*this);
1848 #elif defined(HAVE_GNUTLS)
1849 newCtx
= std::make_shared
<GnuTLSIOCtx
>(*this);
1851 #error "TLS support needed but neither libssl nor GnuTLS were selected"
1855 if (d_alpn
== ALPN::DoT
) {
1856 setupDoTProtocolNegotiation(newCtx
);
1858 else if (d_alpn
== ALPN::DoH
) {
1859 setupDoHProtocolNegotiation(newCtx
);
1862 std::atomic_store_explicit(&d_ctx
, std::move(newCtx
), std::memory_order_release
);
1863 #endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
1867 std::shared_ptr
<TLSCtx
> getTLSContext([[maybe_unused
]] const TLSContextParameters
& params
)
1869 #ifdef HAVE_DNS_OVER_TLS
1870 /* get the "best" available provider */
1871 if (!params
.d_provider
.empty()) {
1872 #if defined(HAVE_GNUTLS)
1873 if (params
.d_provider
== "gnutls") {
1874 return std::make_shared
<GnuTLSIOCtx
>(params
);
1876 #endif /* HAVE_GNUTLS */
1877 #if defined(HAVE_LIBSSL)
1878 if (params
.d_provider
== "openssl") {
1879 return std::make_shared
<OpenSSLTLSIOCtx
>(params
);
1881 #endif /* HAVE_LIBSSL */
1884 #if defined(HAVE_LIBSSL)
1885 return std::make_shared
<OpenSSLTLSIOCtx
>(params
);
1886 #elif defined(HAVE_GNUTLS)
1887 return std::make_shared
<GnuTLSIOCtx
>(params
);
1889 #error "DNS over TLS support needed but neither libssl nor GnuTLS were selected"
1892 #endif /* HAVE_DNS_OVER_TLS */