]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/tcpiohandler.cc
rec: mention rust compiler in compiling docs
[thirdparty/pdns.git] / pdns / tcpiohandler.cc
1
2 #include "config.h"
3 #include "dolog.hh"
4 #include "iputils.hh"
5 #include "lock.hh"
6 #include "tcpiohandler.hh"
7
8 const bool TCPIOHandler::s_disableConnectForUnitTests = false;
9
10 #ifdef HAVE_LIBSODIUM
11 #include <sodium.h>
12 #endif /* HAVE_LIBSODIUM */
13
14 #if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
15 #ifdef HAVE_LIBSSL
16
17 #include <openssl/conf.h>
18 #include <openssl/err.h>
19 #include <openssl/rand.h>
20 #include <openssl/ssl.h>
21 #include <openssl/x509v3.h>
22
23 #include "libssl.hh"
24
25
26 class OpenSSLFrontendContext
27 {
28 public:
29 OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys)
30 {
31 registerOpenSSLUser();
32
33 auto [ctx, warnings] = libssl_init_server_context(tlsConfig, d_ocspResponses);
34 for (const auto& warning : warnings) {
35 warnlog("%s", warning);
36 }
37 d_tlsCtx = std::move(ctx);
38
39 if (!d_tlsCtx) {
40 ERR_print_errors_fp(stderr);
41 throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort());
42 }
43 }
44
45 void cleanup()
46 {
47 d_tlsCtx.reset();
48
49 unregisterOpenSSLUser();
50 }
51
52 OpenSSLTLSTicketKeysRing d_ticketKeys;
53 std::map<int, std::string> d_ocspResponses;
54 std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
55 pdns::UniqueFilePtr d_keyLogFile{nullptr};
56 };
57
58 class OpenSSLSession : public TLSSession
59 {
60 public:
61 OpenSSLSession(std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)>&& sess): d_sess(std::move(sess))
62 {
63 }
64
65 std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> getNative()
66 {
67 return std::move(d_sess);
68 }
69
70 private:
71 std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> d_sess;
72 };
73
74 class OpenSSLTLSConnection: public TLSConnection
75 {
76 public:
77 /* server side connection */
78 OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(std::move(feContext)), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
79 {
80 d_socket = socket;
81
82 if (!d_conn) {
83 vinfolog("Error creating TLS object");
84 if (g_verbose) {
85 ERR_print_errors_fp(stderr);
86 }
87 throw std::runtime_error("Error creating TLS object");
88 }
89
90 if (!SSL_set_fd(d_conn.get(), d_socket)) {
91 throw std::runtime_error("Error assigning socket");
92 }
93
94 SSL_set_ex_data(d_conn.get(), getConnectionIndex(), this);
95 }
96
97 /* client-side connection */
98 OpenSSLTLSConnection(const std::string& hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<SSL_CTX>& tlsCtx): d_tlsCtx(tlsCtx), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx.get()), SSL_free)), d_hostname(hostname), d_timeout(timeout)
99 {
100 d_socket = socket;
101
102 if (!d_conn) {
103 vinfolog("Error creating TLS object");
104 if (g_verbose) {
105 ERR_print_errors_fp(stderr);
106 }
107 throw std::runtime_error("Error creating TLS object");
108 }
109
110 if (!SSL_set_fd(d_conn.get(), d_socket)) {
111 throw std::runtime_error("Error assigning socket");
112 }
113
114 /* set outgoing Server Name Indication */
115 if (!d_hostname.empty() && SSL_set_tlsext_host_name(d_conn.get(), d_hostname.c_str()) != 1) {
116 throw std::runtime_error("Error setting TLS SNI to " + d_hostname);
117 }
118
119 if (hostIsAddr) {
120 #if (OPENSSL_VERSION_NUMBER >= 0x10002000L)
121 X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
122 /* Enable automatic IP checks */
123 X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
124 if (X509_VERIFY_PARAM_set1_ip_asc(param, d_hostname.c_str()) != 1) {
125 throw std::runtime_error("Error setting TLS IP for certificate validation");
126 }
127 #else
128 /* no validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
129 #endif
130 }
131 else {
132 #if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && defined(HAVE_SSL_SET_HOSTFLAGS) // grrr libressl
133 SSL_set_hostflags(d_conn.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
134 if (SSL_set1_host(d_conn.get(), d_hostname.c_str()) != 1) {
135 throw std::runtime_error("Error setting TLS hostname for certificate validation");
136 }
137 #elif (OPENSSL_VERSION_NUMBER >= 0x10002000L)
138 X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
139 /* Enable automatic hostname checks */
140 X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
141 if (X509_VERIFY_PARAM_set1_host(param, d_hostname.c_str(), d_hostname.size()) != 1) {
142 throw std::runtime_error("Error setting TLS hostname for certificate validation");
143 }
144 #else
145 /* no hostname validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
146 #endif
147 }
148
149 SSL_set_ex_data(d_conn.get(), getConnectionIndex(), this);
150 }
151
152 std::vector<int> getAsyncFDs() override
153 {
154 std::vector<int> results;
155 #ifdef SSL_MODE_ASYNC
156 if (SSL_waiting_for_async(d_conn.get()) != 1) {
157 return results;
158 }
159
160 OSSL_ASYNC_FD fds[32];
161 size_t numfds = sizeof(fds)/sizeof(*fds);
162 SSL_get_all_async_fds(d_conn.get(), nullptr, &numfds);
163 if (numfds == 0) {
164 return results;
165 }
166
167 SSL_get_all_async_fds(d_conn.get(), fds, &numfds);
168 results.reserve(numfds);
169 for (size_t idx = 0; idx < numfds; idx++) {
170 results.push_back(fds[idx]);
171 }
172 #endif
173 return results;
174 }
175
176 IOState convertIORequestToIOState(int res) const
177 {
178 int error = SSL_get_error(d_conn.get(), res);
179 if (error == SSL_ERROR_WANT_READ) {
180 return IOState::NeedRead;
181 }
182 else if (error == SSL_ERROR_WANT_WRITE) {
183 return IOState::NeedWrite;
184 }
185 else if (error == SSL_ERROR_SYSCALL) {
186 if (errno == 0) {
187 throw std::runtime_error("TLS connection closed by remote end");
188 }
189 else {
190 throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(errno)));
191 }
192 }
193 else if (error == SSL_ERROR_ZERO_RETURN) {
194 throw std::runtime_error("TLS connection closed by remote end");
195 }
196 #ifdef SSL_MODE_ASYNC
197 else if (error == SSL_ERROR_WANT_ASYNC) {
198 return IOState::Async;
199 }
200 #endif
201 else {
202 if (g_verbose) {
203 throw std::runtime_error("Error while processing TLS connection: (" + std::to_string(error) + ") " + libssl_get_error_string());
204 } else {
205 throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error));
206 }
207 }
208 }
209
210 void handleIORequest(int res, const struct timeval& timeout)
211 {
212 auto state = convertIORequestToIOState(res);
213 if (state == IOState::NeedRead) {
214 res = waitForData(d_socket, timeout.tv_sec, timeout.tv_usec);
215 if (res == 0) {
216 throw std::runtime_error("Timeout while reading from TLS connection");
217 }
218 else if (res < 0) {
219 throw std::runtime_error("Error waiting to read from TLS connection");
220 }
221 }
222 else if (state == IOState::NeedWrite) {
223 res = waitForRWData(d_socket, false, timeout.tv_sec, timeout.tv_usec);
224 if (res == 0) {
225 throw std::runtime_error("Timeout while writing to TLS connection");
226 }
227 else if (res < 0) {
228 throw std::runtime_error("Error waiting to write to TLS connection");
229 }
230 }
231 }
232
233 IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
234 {
235 /* sorry */
236 (void) fastOpen;
237 (void) remote;
238
239 int res = SSL_connect(d_conn.get());
240 if (res == 1) {
241 return IOState::Done;
242 }
243 else if (res < 0) {
244 return convertIORequestToIOState(res);
245 }
246
247 throw std::runtime_error("Error establishing a TLS connection");
248 }
249
250 void connect(bool fastOpen, const ComboAddress& remote, const struct timeval &timeout) override
251 {
252 /* sorry */
253 (void) fastOpen;
254 (void) remote;
255
256 struct timeval start{0,0};
257 struct timeval remainingTime = timeout;
258 if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
259 gettimeofday(&start, nullptr);
260 }
261
262 int res = 0;
263 do {
264 res = SSL_connect(d_conn.get());
265 if (res < 0) {
266 handleIORequest(res, remainingTime);
267 }
268
269 if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
270 struct timeval now;
271 gettimeofday(&now, nullptr);
272 struct timeval elapsed = now - start;
273 if (now < start || remainingTime < elapsed) {
274 throw runtime_error("Timeout while establishing TLS connection");
275 }
276 start = now;
277 remainingTime = remainingTime - elapsed;
278 }
279 }
280 while (res != 1);
281 }
282
283 IOState tryHandshake() override
284 {
285 if (!d_feContext) {
286 /* In client mode, the handshake is initiated by the call to SSL_connect()
287 done from connect()/tryConnect().
288 In blocking mode it does not return before the handshake has been finished,
289 and in non-blocking mode calling SSL_connect() once is enough for SSL_write()
290 and SSL_read() to transparently continue to negotiate the connection after that
291 (equivalent to doing SSL_set_connect_state() plus trying to write).
292 */
293 return IOState::Done;
294 }
295
296 /* As explained above in the client-mode block, we only need to call SSL_accept() once
297 for SSL_write() and SSL_read() to transparently continue to negotiate the connection after that.
298 It is equivalent to calling SSL_set_accept_state() plus trying to read.
299 */
300 int res = SSL_accept(d_conn.get());
301 if (res == 1) {
302 return IOState::Done;
303 }
304 else if (res < 0) {
305 return convertIORequestToIOState(res);
306 }
307
308 throw std::runtime_error("Error accepting TLS connection");
309 }
310
311 void doHandshake() override
312 {
313 if (!d_feContext) {
314 /* we are a client, nothing to do, see the non-blocking version */
315 return;
316 }
317
318 int res = 0;
319 do {
320 res = SSL_accept(d_conn.get());
321 if (res < 0) {
322 handleIORequest(res, d_timeout);
323 }
324 }
325 while (res < 0);
326
327 if (res != 1) {
328 throw std::runtime_error("Error accepting TLS connection");
329 }
330 }
331
332 IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
333 {
334 if (!d_feContext && !d_connected) {
335 if (d_ktls) {
336 /* work-around to get kTLS to be started, as we cannot do that until after the socket has been connected */
337 SSL_set_fd(d_conn.get(), SSL_get_fd(d_conn.get()));
338 }
339 }
340
341 do {
342 int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
343 if (res <= 0) {
344 return convertIORequestToIOState(res);
345 }
346 else {
347 pos += static_cast<size_t>(res);
348 }
349 }
350 while (pos < toWrite);
351
352 if (!d_connected) {
353 d_connected = true;
354 }
355
356 return IOState::Done;
357 }
358
359 IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
360 {
361 do {
362 int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
363 if (res <= 0) {
364 return convertIORequestToIOState(res);
365 }
366 else {
367 pos += static_cast<size_t>(res);
368 if (allowIncomplete) {
369 break;
370 }
371 }
372 }
373 while (pos < toRead);
374 return IOState::Done;
375 }
376
377 size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
378 {
379 size_t got = 0;
380 struct timeval start = {0, 0};
381 struct timeval remainingTime = totalTimeout;
382 if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
383 gettimeofday(&start, nullptr);
384 }
385
386 do {
387 int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
388 if (res <= 0) {
389 handleIORequest(res, readTimeout);
390 }
391 else {
392 got += static_cast<size_t>(res);
393 if (allowIncomplete) {
394 break;
395 }
396 }
397
398 if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
399 struct timeval now;
400 gettimeofday(&now, nullptr);
401 struct timeval elapsed = now - start;
402 if (now < start || remainingTime < elapsed) {
403 throw runtime_error("Timeout while reading data");
404 }
405 start = now;
406 remainingTime = remainingTime - elapsed;
407 }
408 }
409 while (got < bufferSize);
410
411 return got;
412 }
413
414 size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
415 {
416 size_t got = 0;
417 do {
418 int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
419 if (res <= 0) {
420 handleIORequest(res, writeTimeout);
421 }
422 else {
423 got += static_cast<size_t>(res);
424 }
425 }
426 while (got < bufferSize);
427
428 return got;
429 }
430
431 bool isUsable() const override
432 {
433 if (!d_conn) {
434 return false;
435 }
436
437 char buf;
438 int res = SSL_peek(d_conn.get(), &buf, sizeof(buf));
439 if (res > 0) {
440 return true;
441 }
442 try {
443 convertIORequestToIOState(res);
444 return true;
445 }
446 catch (...) {
447 return false;
448 }
449
450 return false;
451 }
452
453 void close() override
454 {
455 if (d_conn) {
456 SSL_shutdown(d_conn.get());
457 }
458 }
459
460 std::string getServerNameIndication() const override
461 {
462 if (d_conn) {
463 const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
464 if (value) {
465 return std::string(value);
466 }
467 }
468 return std::string();
469 }
470
471 std::vector<uint8_t> getNextProtocol() const override
472 {
473 std::vector<uint8_t> result;
474 if (!d_conn) {
475 return result;
476 }
477
478 const unsigned char* alpn = nullptr;
479 unsigned int alpnLen = 0;
480 #ifndef DISABLE_NPN
481 #ifdef HAVE_SSL_GET0_NEXT_PROTO_NEGOTIATED
482 SSL_get0_next_proto_negotiated(d_conn.get(), &alpn, &alpnLen);
483 #endif /* HAVE_SSL_GET0_NEXT_PROTO_NEGOTIATED */
484 #endif /* DISABLE_NPN */
485 #ifdef HAVE_SSL_GET0_ALPN_SELECTED
486 if (alpn == nullptr) {
487 SSL_get0_alpn_selected(d_conn.get(), &alpn, &alpnLen);
488 }
489 #endif /* HAVE_SSL_GET0_ALPN_SELECTED */
490 if (alpn != nullptr && alpnLen > 0) {
491 result.insert(result.end(), alpn, alpn + alpnLen);
492 }
493 return result;
494 }
495
496 LibsslTLSVersion getTLSVersion() const override
497 {
498 auto proto = SSL_version(d_conn.get());
499 switch (proto) {
500 case TLS1_VERSION:
501 return LibsslTLSVersion::TLS10;
502 case TLS1_1_VERSION:
503 return LibsslTLSVersion::TLS11;
504 case TLS1_2_VERSION:
505 return LibsslTLSVersion::TLS12;
506 #ifdef TLS1_3_VERSION
507 case TLS1_3_VERSION:
508 return LibsslTLSVersion::TLS13;
509 #endif /* TLS1_3_VERSION */
510 default:
511 return LibsslTLSVersion::Unknown;
512 }
513 }
514
515 bool hasSessionBeenResumed() const override
516 {
517 if (d_conn) {
518 return SSL_session_reused(d_conn.get()) != 0;
519 }
520 return false;
521 }
522
523 std::vector<std::unique_ptr<TLSSession>> getSessions() override
524 {
525 return std::move(d_tlsSessions);
526 }
527
528 void setSession(std::unique_ptr<TLSSession>& session) override
529 {
530 auto sess = dynamic_cast<OpenSSLSession*>(session.get());
531 if (!sess) {
532 throw std::runtime_error("Unable to convert OpenSSL session");
533 }
534
535 auto native = sess->getNative();
536 auto ret = SSL_set_session(d_conn.get(), native.get());
537 if (ret != 1) {
538 throw std::runtime_error("Error setting up session: " + libssl_get_error_string());
539 }
540 session.reset();
541 }
542
543 void addNewTicket(SSL_SESSION* session)
544 {
545 d_tlsSessions.push_back(std::make_unique<OpenSSLSession>(std::unique_ptr<SSL_SESSION, void (*)(SSL_SESSION*)>(session, SSL_SESSION_free)));
546 }
547
548 void enableKTLS()
549 {
550 d_ktls = true;
551 }
552
553 static void generateConnectionIndexIfNeeded()
554 {
555 auto init = s_initTLSConnIndex.lock();
556 if (*init == true) {
557 return;
558 }
559
560 /* not initialized yet */
561 s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
562 if (s_tlsConnIndex == -1) {
563 throw std::runtime_error("Error getting an index for TLS connection data");
564 }
565
566 *init = true;
567 }
568
569 static int getConnectionIndex()
570 {
571 return s_tlsConnIndex;
572 }
573
574 private:
575 static LockGuarded<bool> s_initTLSConnIndex;
576 static int s_tlsConnIndex;
577 std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
578 /* server context */
579 std::shared_ptr<OpenSSLFrontendContext> d_feContext;
580 /* client context */
581 std::shared_ptr<SSL_CTX> d_tlsCtx;
582 std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
583 std::string d_hostname;
584 struct timeval d_timeout;
585 bool d_connected{false};
586 bool d_ktls{false};
587 };
588
589 LockGuarded<bool> OpenSSLTLSConnection::s_initTLSConnIndex{false};
590 int OpenSSLTLSConnection::s_tlsConnIndex{-1};
591
592 class OpenSSLTLSIOCtx: public TLSCtx
593 {
594 public:
595 /* server side context */
596 OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig))
597 {
598 OpenSSLTLSConnection::generateConnectionIndexIfNeeded();
599
600 d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
601
602 if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) {
603 /* use our own ticket keys handler so we can rotate them */
604 #if OPENSSL_VERSION_MAJOR >= 3
605 SSL_CTX_set_tlsext_ticket_key_evp_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
606 #else
607 SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
608 #endif
609 libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get());
610 }
611
612 #ifndef DISABLE_OCSP_STAPLING
613 if (!d_feContext->d_ocspResponses.empty()) {
614 SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
615 SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses);
616 }
617 #endif /* DISABLE_OCSP_STAPLING */
618
619 if (fe.d_tlsConfig.d_readAhead) {
620 SSL_CTX_set_read_ahead(d_feContext->d_tlsCtx.get(), 1);
621 }
622
623 libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters);
624
625 if (!fe.d_tlsConfig.d_keyLogFile.empty()) {
626 d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile);
627 }
628
629 try {
630 if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
631 handleTicketsKeyRotation(time(nullptr));
632 }
633 else {
634 OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
635 }
636 }
637 catch (const std::exception& e) {
638 throw;
639 }
640 }
641
642 /* client side context */
643 OpenSSLTLSIOCtx(const TLSContextParameters& params)
644 {
645 int sslOptions =
646 SSL_OP_NO_SSLv2 |
647 SSL_OP_NO_SSLv3 |
648 SSL_OP_NO_COMPRESSION |
649 SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
650 SSL_OP_SINGLE_DH_USE |
651 SSL_OP_SINGLE_ECDH_USE |
652 #ifdef SSL_OP_IGNORE_UNEXPECTED_EOF
653 SSL_OP_IGNORE_UNEXPECTED_EOF |
654 #endif
655 SSL_OP_CIPHER_SERVER_PREFERENCE;
656 if (!params.d_enableRenegotiation) {
657 #ifdef SSL_OP_NO_RENEGOTIATION
658 sslOptions |= SSL_OP_NO_RENEGOTIATION;
659 #elif defined(SSL_OP_NO_CLIENT_RENEGOTIATION)
660 sslOptions |= SSL_OP_NO_CLIENT_RENEGOTIATION;
661 #endif
662 }
663
664 if (params.d_ktls) {
665 #ifdef SSL_OP_ENABLE_KTLS
666 sslOptions |= SSL_OP_ENABLE_KTLS;
667 d_ktls = true;
668 #endif /* SSL_OP_ENABLE_KTLS */
669 }
670
671 registerOpenSSLUser();
672
673 OpenSSLTLSConnection::generateConnectionIndexIfNeeded();
674
675 #ifdef HAVE_TLS_CLIENT_METHOD
676 d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
677 #else
678 d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
679 #endif
680 if (!d_tlsCtx) {
681 ERR_print_errors_fp(stderr);
682 throw std::runtime_error("Error creating TLS context");
683 }
684
685 SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
686 #if defined(SSL_CTX_set_ecdh_auto)
687 SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
688 #endif
689
690 if (!params.d_ciphers.empty()) {
691 if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), params.d_ciphers.c_str()) != 1) {
692 ERR_print_errors_fp(stderr);
693 throw std::runtime_error("Error setting the cipher list to '" + params.d_ciphers + "' for the TLS context");
694 }
695 }
696 #ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
697 if (!params.d_ciphers13.empty()) {
698 if (SSL_CTX_set_ciphersuites(d_tlsCtx.get(), params.d_ciphers13.c_str()) != 1) {
699 ERR_print_errors_fp(stderr);
700 throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + params.d_ciphers13 + "' for the TLS context");
701 }
702 }
703 #endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
704
705 if (params.d_validateCertificates) {
706 if (params.d_caStore.empty()) {
707 if (SSL_CTX_set_default_verify_paths(d_tlsCtx.get()) != 1) {
708 throw std::runtime_error("Error adding the system's default trusted CAs");
709 }
710 } else {
711 if (SSL_CTX_load_verify_locations(d_tlsCtx.get(), params.d_caStore.c_str(), nullptr) != 1) {
712 throw std::runtime_error("Error adding the trusted CAs file " + params.d_caStore);
713 }
714 }
715
716 SSL_CTX_set_verify(d_tlsCtx.get(), SSL_VERIFY_PEER, nullptr);
717 #if (OPENSSL_VERSION_NUMBER < 0x10002000L)
718 warnlog("TLS hostname validation requested but not supported for OpenSSL < 1.0.2");
719 #endif
720 }
721
722 /* we need to set SSL_SESS_CACHE_CLIENT for the "new ticket" callback (below) to be called,
723 but we don't want OpenSSL to cache the session itself so we set SSL_SESS_CACHE_NO_INTERNAL_STORE as well */
724 SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL_STORE);
725 SSL_CTX_sess_set_new_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::newTicketFromServerCb);
726
727 #ifdef SSL_MODE_RELEASE_BUFFERS
728 if (params.d_releaseBuffers) {
729 SSL_CTX_set_mode(d_tlsCtx.get(), SSL_MODE_RELEASE_BUFFERS);
730 }
731 #endif
732 }
733
734 ~OpenSSLTLSIOCtx() override
735 {
736 d_tlsCtx.reset();
737 unregisterOpenSSLUser();
738 }
739
740 #if OPENSSL_VERSION_MAJOR >= 3
741 static int ticketKeyCb(SSL* s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx, int enc)
742 #else
743 static int ticketKeyCb(SSL* s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx, int enc)
744 #endif
745 {
746 auto* ctx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(s));
747 if (ctx == nullptr) {
748 return -1;
749 }
750
751 int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc);
752 if (enc == 0) {
753 if (ret == 0 || ret == 2) {
754 auto* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(s, OpenSSLTLSConnection::getConnectionIndex()));
755 if (conn != nullptr) {
756 if (ret == 0) {
757 conn->setUnknownTicketKey();
758 }
759 else if (ret == 2) {
760 conn->setResumedFromInactiveTicketKey();
761 }
762 }
763 }
764 }
765
766 return ret;
767 }
768
769 #ifndef DISABLE_OCSP_STAPLING
770 static int ocspStaplingCb(SSL* ssl, void* arg)
771 {
772 if (ssl == nullptr || arg == nullptr) {
773 return SSL_TLSEXT_ERR_NOACK;
774 }
775 const auto ocspMap = reinterpret_cast<std::map<int, std::string>*>(arg);
776 return libssl_ocsp_stapling_callback(ssl, *ocspMap);
777 }
778 #endif /* DISABLE_OCSP_STAPLING */
779
780 static int newTicketFromServerCb(SSL* ssl, SSL_SESSION* session)
781 {
782 OpenSSLTLSConnection* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(ssl, OpenSSLTLSConnection::getConnectionIndex()));
783 if (session == nullptr || conn == nullptr) {
784 return 0;
785 }
786
787 conn->addNewTicket(session);
788 return 1;
789 }
790
791 std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
792 {
793 handleTicketsKeyRotation(now);
794
795 return std::make_unique<OpenSSLTLSConnection>(socket, timeout, d_feContext);
796 }
797
798 std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override
799 {
800 auto conn = std::make_unique<OpenSSLTLSConnection>(host, hostIsAddr, socket, timeout, d_tlsCtx);
801 if (d_ktls) {
802 conn->enableKTLS();
803 }
804 return conn;
805 }
806
807 void rotateTicketsKey(time_t now) override
808 {
809 d_feContext->d_ticketKeys.rotateTicketsKey(now);
810
811 if (d_ticketsKeyRotationDelay > 0) {
812 d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
813 }
814 }
815
816 void loadTicketsKeys(const std::string& keyFile) final
817 {
818 d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
819
820 if (d_ticketsKeyRotationDelay > 0) {
821 d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
822 }
823 }
824
825 size_t getTicketsKeysCount() override
826 {
827 return d_feContext->d_ticketKeys.getKeysCount();
828 }
829
830 std::string getName() const override
831 {
832 return "openssl";
833 }
834
835 bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos) override
836 {
837 if (d_feContext && d_feContext->d_tlsCtx) {
838 d_alpnProtos = protos;
839 libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this);
840 return true;
841 }
842 if (d_tlsCtx) {
843 return libssl_set_alpn_protos(d_tlsCtx.get(), protos);
844 }
845 return false;
846 }
847
848 #ifndef DISABLE_NPN
849 bool setNextProtocolSelectCallback(bool(*cb)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) override
850 {
851 d_nextProtocolSelectCallback = cb;
852 libssl_set_npn_select_callback(d_tlsCtx.get(), npnSelectCallback, this);
853 return true;
854 }
855 #endif /* DISABLE_NPN */
856
857 private:
858 /* called in a client context, if the client advertised more than one ALPN values and the server returned more than one as well, to select the one to use. */
859 #ifndef DISABLE_NPN
860 static int npnSelectCallback(SSL* /* s */, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
861 {
862 if (!arg) {
863 return SSL_TLSEXT_ERR_ALERT_WARNING;
864 }
865 OpenSSLTLSIOCtx* obj = reinterpret_cast<OpenSSLTLSIOCtx*>(arg);
866 if (obj->d_nextProtocolSelectCallback) {
867 return (*obj->d_nextProtocolSelectCallback)(out, outlen, in, inlen) ? SSL_TLSEXT_ERR_OK : SSL_TLSEXT_ERR_ALERT_WARNING;
868 }
869
870 return SSL_TLSEXT_ERR_OK;
871 }
872 #endif /* NPN */
873
874 static int alpnServerSelectCallback(SSL*, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
875 {
876 if (!arg) {
877 return SSL_TLSEXT_ERR_ALERT_WARNING;
878 }
879 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): OpenSSL's API
880 OpenSSLTLSIOCtx* obj = reinterpret_cast<OpenSSLTLSIOCtx*>(arg);
881
882 const pdns::views::UnsignedCharView inView(in, inlen);
883 // Server preference algorithm as per RFC 7301 section 3.2
884 for (const auto& tentative : obj->d_alpnProtos) {
885 size_t pos = 0;
886 while (pos < inView.size()) {
887 size_t protoLen = inView.at(pos);
888 pos++;
889 if (protoLen > (inlen - pos)) {
890 /* something is very wrong */
891 return SSL_TLSEXT_ERR_ALERT_WARNING;
892 }
893
894 if (tentative.size() == protoLen && memcmp(&inView.at(pos), tentative.data(), tentative.size()) == 0) {
895 *out = &inView.at(pos);
896 *outlen = protoLen;
897 return SSL_TLSEXT_ERR_OK;
898 }
899 pos += protoLen;
900 }
901 }
902
903 return SSL_TLSEXT_ERR_NOACK;
904 }
905
906 std::vector<std::vector<uint8_t>> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent
907 std::shared_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
908 std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
909 bool (*d_nextProtocolSelectCallback)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen){nullptr};
910 bool d_ktls{false};
911 };
912
913 #endif /* HAVE_LIBSSL */
914
915 #ifdef HAVE_GNUTLS
916 #include <gnutls/gnutls.h>
917 #include <gnutls/x509.h>
918
919 static void safe_memory_lock(void* data, size_t size)
920 {
921 #ifdef HAVE_LIBSODIUM
922 sodium_mlock(data, size);
923 #endif
924 }
925
/* Releases `size` bytes of sensitive memory at `data`. With libsodium the region goes to sodium_munlock().
   Otherwise it is zeroed with the first available primitive: explicit_bzero(), explicit_memset(),
   gnutls_memset(). The last resort is a memset() loop with a volatile read-back, presumably so the wipe
   cannot be optimized away. */
static void safe_memory_release(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  /* NOTE(review): relies on sodium_munlock() zeroing the region before unlocking it (per libsodium docs) — confirm */
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  if (size == 0) {
    return;
  }

  /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
  volatile unsigned int zeroIndex = 0;
  volatile unsigned char* bytes = static_cast<volatile unsigned char*>(data);
  for (;;) {
    memset(data, 0, size);
    if (bytes[zeroIndex] == 0) {
      break;
    }
  }
#endif
}
949
950 class GnuTLSTicketsKey
951 {
952 public:
953 GnuTLSTicketsKey()
954 {
955 if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
956 throw std::runtime_error("Error generating tickets key for TLS context");
957 }
958
959 safe_memory_lock(d_key.data, d_key.size);
960 }
961
962 GnuTLSTicketsKey(const std::string& keyFile)
963 {
964 /* to be sure we are loading the correct amount of data, which
965 may change between versions, let's generate a correct key first */
966 if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
967 throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
968 }
969
970 safe_memory_lock(d_key.data, d_key.size);
971
972 try {
973 ifstream file(keyFile);
974 file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
975
976 if (file.fail()) {
977 file.close();
978 throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
979 }
980
981 file.close();
982 }
983 catch (const std::exception& e) {
984 safe_memory_release(d_key.data, d_key.size);
985 gnutls_free(d_key.data);
986 d_key.data = nullptr;
987 throw;
988 }
989 }
990
991 ~GnuTLSTicketsKey()
992 {
993 if (d_key.data != nullptr && d_key.size > 0) {
994 safe_memory_release(d_key.data, d_key.size);
995 }
996 gnutls_free(d_key.data);
997 d_key.data = nullptr;
998 }
999 const gnutls_datum_t& getKey() const
1000 {
1001 return d_key;
1002 }
1003
1004 private:
1005 gnutls_datum_t d_key{nullptr, 0};
1006 };
1007
1008 class GnuTLSSession : public TLSSession
1009 {
1010 public:
1011 GnuTLSSession(gnutls_datum_t& sess): d_sess(sess)
1012 {
1013 sess.data = nullptr;
1014 sess.size = 0;
1015 }
1016
1017 ~GnuTLSSession() override
1018 {
1019 if (d_sess.data != nullptr && d_sess.size > 0) {
1020 safe_memory_release(d_sess.data, d_sess.size);
1021 }
1022 gnutls_free(d_sess.data);
1023 d_sess.data = nullptr;
1024 }
1025
1026 const gnutls_datum_t& getNative()
1027 {
1028 return d_sess;
1029 }
1030
1031 private:
1032 gnutls_datum_t d_sess{nullptr, 0};
1033 };
1034
1035 class GnuTLSConnection: public TLSConnection
1036 {
1037 public:
1038 /* server side connection */
1039 GnuTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_creds(creds), d_ticketsKey(ticketsKey), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit))
1040 {
1041 unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
1042 #ifdef GNUTLS_NO_SIGNAL
1043 sslOptions |= GNUTLS_NO_SIGNAL;
1044 #endif
1045
1046 d_socket = socket;
1047
1048 gnutls_session_t conn;
1049 if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
1050 throw std::runtime_error("Error creating TLS connection");
1051 }
1052
1053 d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
1054 conn = nullptr;
1055
1056 if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get()) != GNUTLS_E_SUCCESS) {
1057 throw std::runtime_error("Error setting certificate and key to TLS connection");
1058 }
1059
1060 if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
1061 throw std::runtime_error("Error setting ciphers to TLS connection");
1062 }
1063
1064 if (enableTickets && d_ticketsKey) {
1065 const gnutls_datum_t& key = d_ticketsKey->getKey();
1066 if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
1067 throw std::runtime_error("Error setting the tickets key to TLS connection");
1068 }
1069 }
1070
1071 gnutls_transport_set_int(d_conn.get(), d_socket);
1072
1073 /* timeouts are in milliseconds */
1074 gnutls_handshake_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
1075 gnutls_record_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
1076 }
1077
1078 /* client-side connection */
1079 GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, bool validateCerts): d_creds(creds), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host), d_client(true)
1080 {
1081 unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
1082 #ifdef GNUTLS_NO_SIGNAL
1083 sslOptions |= GNUTLS_NO_SIGNAL;
1084 #endif
1085
1086 d_socket = socket;
1087
1088 gnutls_session_t conn;
1089 if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
1090 throw std::runtime_error("Error creating TLS connection");
1091 }
1092
1093 d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
1094 conn = nullptr;
1095
1096 int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get());
1097 if (rc != GNUTLS_E_SUCCESS) {
1098 throw std::runtime_error("Error setting certificate and key to TLS connection: " + std::string(gnutls_strerror(rc)));
1099 }
1100
1101 rc = gnutls_priority_set(d_conn.get(), priorityCache);
1102 if (rc != GNUTLS_E_SUCCESS) {
1103 throw std::runtime_error("Error setting ciphers to TLS connection: " + std::string(gnutls_strerror(rc)));
1104 }
1105
1106 gnutls_transport_set_int(d_conn.get(), d_socket);
1107
1108 /* timeouts are in milliseconds */
1109 gnutls_handshake_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
1110 gnutls_record_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
1111
1112 #ifdef HAVE_GNUTLS_SESSION_SET_VERIFY_CERT
1113 if (validateCerts && !d_host.empty()) {
1114 gnutls_session_set_verify_cert(d_conn.get(), d_host.c_str(), GNUTLS_VERIFY_ALLOW_UNSORTED_CHAIN);
1115 rc = gnutls_server_name_set(d_conn.get(), GNUTLS_NAME_DNS, d_host.c_str(), d_host.size());
1116 if (rc != GNUTLS_E_SUCCESS) {
1117 throw std::runtime_error("Error setting the SNI value to '" + d_host + "' on TLS connection: " + std::string(gnutls_strerror(rc)));
1118 }
1119 }
1120 #else
1121 /* no hostname validation for you */
1122 #endif
1123
1124 /* allow access to our data in the callbacks */
1125 gnutls_session_set_ptr(d_conn.get(), this);
1126 gnutls_handshake_set_hook_function(d_conn.get(), GNUTLS_HANDSHAKE_NEW_SESSION_TICKET, GNUTLS_HOOK_POST, newTicketFromServerCb);
1127 }
1128
1129 /* The callback prototype changed in 3.4.0. */
1130 #if GNUTLS_VERSION_NUMBER >= 0x030400
1131 static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */, const gnutls_datum_t* /* msg */)
1132 #else
1133 static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */)
1134 #endif /* GNUTLS_VERSION_NUMBER >= 0x030400 */
1135 {
1136 if (htype != GNUTLS_HANDSHAKE_NEW_SESSION_TICKET || post != GNUTLS_HOOK_POST || session == nullptr) {
1137 return 0;
1138 }
1139
1140 GnuTLSConnection* conn = reinterpret_cast<GnuTLSConnection*>(gnutls_session_get_ptr(session));
1141 if (conn == nullptr) {
1142 return 0;
1143 }
1144
1145 gnutls_datum_t sess{nullptr, 0};
1146 auto ret = gnutls_session_get_data2(session, &sess);
1147 /* GnuTLS returns a 'fake' ticket of 4 bytes set to zero when there is no ticket available */
1148 if (ret != GNUTLS_E_SUCCESS || sess.size <= 4) {
1149 throw std::runtime_error("Error getting GnuTLSSession: " + std::string(gnutls_strerror(ret)));
1150 }
1151 conn->d_tlsSessions.push_back(std::make_unique<GnuTLSSession>(sess));
1152 return 0;
1153 }
1154
1155 IOState tryConnect(bool fastOpen, [[maybe_unused]] const ComboAddress& remote) override
1156 {
1157 int ret = 0;
1158
1159 if (fastOpen) {
1160 #ifdef HAVE_GNUTLS_TRANSPORT_SET_FASTOPEN
1161 gnutls_transport_set_fastopen(d_conn.get(), d_socket, const_cast<struct sockaddr*>(reinterpret_cast<const struct sockaddr*>(&remote)), remote.getSocklen(), 0);
1162 #endif
1163 }
1164
1165 do {
1166 ret = gnutls_handshake(d_conn.get());
1167 if (ret == GNUTLS_E_SUCCESS) {
1168 d_handshakeDone = true;
1169 return IOState::Done;
1170 }
1171 else if (ret == GNUTLS_E_AGAIN) {
1172 int direction = gnutls_record_get_direction(d_conn.get());
1173 return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
1174 }
1175 else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
1176 throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1177 }
1178 } while (ret == GNUTLS_E_INTERRUPTED);
1179
1180 throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1181 }
1182
1183 void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) override
1184 {
1185 struct timeval start = {0, 0};
1186 struct timeval remainingTime = timeout;
1187 if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
1188 gettimeofday(&start, nullptr);
1189 }
1190
1191 IOState state;
1192 do {
1193 state = tryConnect(fastOpen, remote);
1194 if (state == IOState::Done) {
1195 return;
1196 }
1197 else if (state == IOState::NeedRead) {
1198 int result = waitForData(d_socket, remainingTime.tv_sec, remainingTime.tv_usec);
1199 if (result <= 0) {
1200 throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
1201 }
1202 }
1203 else if (state == IOState::NeedWrite) {
1204 int result = waitForRWData(d_socket, false, remainingTime.tv_sec, remainingTime.tv_usec);
1205 if (result <= 0) {
1206 throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
1207 }
1208 }
1209
1210 if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
1211 struct timeval now;
1212 gettimeofday(&now, nullptr);
1213 struct timeval elapsed = now - start;
1214 if (now < start || remainingTime < elapsed) {
1215 throw runtime_error("Timeout while establishing TLS connection");
1216 }
1217 start = now;
1218 remainingTime = remainingTime - elapsed;
1219 }
1220 }
1221 while (state != IOState::Done);
1222 }
1223
1224 void doHandshake() override
1225 {
1226 int ret = 0;
1227 do {
1228 ret = gnutls_handshake(d_conn.get());
1229 if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
1230 if (d_client) {
1231 throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1232 }
1233 else {
1234 throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
1235 }
1236 }
1237 }
1238 while (ret != GNUTLS_E_SUCCESS && ret == GNUTLS_E_INTERRUPTED);
1239
1240 d_handshakeDone = true;
1241 }
1242
1243 IOState tryHandshake() override
1244 {
1245 int ret = 0;
1246
1247 do {
1248 ret = gnutls_handshake(d_conn.get());
1249 if (ret == GNUTLS_E_SUCCESS) {
1250 d_handshakeDone = true;
1251 return IOState::Done;
1252 }
1253 else if (ret == GNUTLS_E_AGAIN) {
1254 int direction = gnutls_record_get_direction(d_conn.get());
1255 return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
1256 }
1257 else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
1258 if (d_client) {
1259 std::string error;
1260 #ifdef HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS
1261 if (ret == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) {
1262 gnutls_datum_t out;
1263 if (gnutls_certificate_verification_status_print(gnutls_session_get_verify_cert_status(d_conn.get()), gnutls_certificate_type_get(d_conn.get()), &out, 0) == 0) {
1264 error = " (" + std::string(reinterpret_cast<const char*>(out.data)) + ")";
1265 gnutls_free(out.data);
1266 }
1267 }
1268 #endif /* HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS */
1269 throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)) + error);
1270 }
1271 else {
1272 throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1273 }
1274 }
1275 } while (ret == GNUTLS_E_INTERRUPTED);
1276
1277 if (d_client) {
1278 throw std::runtime_error("Error establishinging a new connection: " + std::string(gnutls_strerror(ret)));
1279 }
1280 else {
1281 throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
1282 }
1283 }
1284
1285 IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
1286 {
1287 if (!d_handshakeDone) {
1288 /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
1289 we need to keep calling gnutls_handshake() until the handshake has been finished. */
1290 auto state = tryHandshake();
1291 if (state != IOState::Done) {
1292 return state;
1293 }
1294 }
1295
1296 do {
1297 ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
1298 if (res == 0) {
1299 throw std::runtime_error("Error writing to TLS connection");
1300 }
1301 else if (res > 0) {
1302 pos += static_cast<size_t>(res);
1303 }
1304 else if (res < 0) {
1305 if (gnutls_error_is_fatal(res)) {
1306 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
1307 }
1308 else if (res == GNUTLS_E_AGAIN) {
1309 return IOState::NeedWrite;
1310 }
1311 vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
1312 }
1313 }
1314 while (pos < toWrite);
1315 return IOState::Done;
1316 }
1317
1318 IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
1319 {
1320 if (!d_handshakeDone) {
1321 /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
1322 we need to keep calling gnutls_handshake() until the handshake has been finished. */
1323 auto state = tryHandshake();
1324 if (state != IOState::Done) {
1325 return state;
1326 }
1327 }
1328
1329 do {
1330 ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
1331 if (res == 0) {
1332 throw std::runtime_error("EOF while reading from TLS connection");
1333 }
1334 else if (res > 0) {
1335 pos += static_cast<size_t>(res);
1336 if (allowIncomplete) {
1337 break;
1338 }
1339 }
1340 else if (res < 0) {
1341 if (gnutls_error_is_fatal(res)) {
1342 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
1343 }
1344 else if (res == GNUTLS_E_AGAIN) {
1345 return IOState::NeedRead;
1346 }
1347 vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
1348 }
1349 }
1350 while (pos < toRead);
1351 return IOState::Done;
1352 }
1353
1354 size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
1355 {
1356 size_t got = 0;
1357 struct timeval start{0,0};
1358 struct timeval remainingTime = totalTimeout;
1359 if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
1360 gettimeofday(&start, nullptr);
1361 }
1362
1363 do {
1364 ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
1365 if (res == 0) {
1366 throw std::runtime_error("EOF while reading from TLS connection");
1367 }
1368 else if (res > 0) {
1369 got += static_cast<size_t>(res);
1370 if (allowIncomplete) {
1371 break;
1372 }
1373 }
1374 else if (res < 0) {
1375 if (gnutls_error_is_fatal(res)) {
1376 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
1377 }
1378 else if (res == GNUTLS_E_AGAIN) {
1379 int result = waitForData(d_socket, readTimeout.tv_sec, readTimeout.tv_usec);
1380 if (result <= 0) {
1381 throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result));
1382 }
1383 }
1384 else {
1385 vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
1386 }
1387 }
1388
1389 if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
1390 struct timeval now;
1391 gettimeofday(&now, nullptr);
1392 struct timeval elapsed = now - start;
1393 if (now < start || remainingTime < elapsed) {
1394 throw runtime_error("Timeout while reading data");
1395 }
1396 start = now;
1397 remainingTime = remainingTime - elapsed;
1398 }
1399 }
1400 while (got < bufferSize);
1401
1402 return got;
1403 }
1404
1405 size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
1406 {
1407 size_t got = 0;
1408
1409 do {
1410 ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
1411 if (res == 0) {
1412 throw std::runtime_error("Error writing to TLS connection");
1413 }
1414 else if (res > 0) {
1415 got += static_cast<size_t>(res);
1416 }
1417 else if (res < 0) {
1418 if (gnutls_error_is_fatal(res)) {
1419 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
1420 }
1421 else if (res == GNUTLS_E_AGAIN) {
1422 int result = waitForRWData(d_socket, false, writeTimeout.tv_sec, writeTimeout.tv_usec);
1423 if (result <= 0) {
1424 throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
1425 }
1426 }
1427 else {
1428 vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
1429 }
1430 }
1431 }
1432 while (got < bufferSize);
1433
1434 return got;
1435 }
1436
1437 bool isUsable() const override
1438 {
1439 if (!d_conn) {
1440 return false;
1441 }
1442
1443 /* as far as I can tell we can't peek so we cannot do better */
1444 return isTCPSocketUsable(d_socket);
1445 }
1446
1447 std::string getServerNameIndication() const override
1448 {
1449 if (d_conn) {
1450 unsigned int type;
1451 size_t name_len = 256;
1452 std::string sni;
1453 sni.resize(name_len);
1454
1455 int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0);
1456 if (res == GNUTLS_E_SUCCESS) {
1457 sni.resize(name_len);
1458 return sni;
1459 }
1460 }
1461 return std::string();
1462 }
1463
1464 std::vector<uint8_t> getNextProtocol() const override
1465 {
1466 std::vector<uint8_t> result;
1467 if (!d_conn) {
1468 return result;
1469 }
1470 gnutls_datum_t next;
1471 if (gnutls_alpn_get_selected_protocol(d_conn.get(), &next) != GNUTLS_E_SUCCESS) {
1472 return result;
1473 }
1474 result.insert(result.end(), next.data, next.data + next.size);
1475 return result;
1476 }
1477
1478 LibsslTLSVersion getTLSVersion() const override
1479 {
1480 auto proto = gnutls_protocol_get_version(d_conn.get());
1481 switch (proto) {
1482 case GNUTLS_TLS1_0:
1483 return LibsslTLSVersion::TLS10;
1484 case GNUTLS_TLS1_1:
1485 return LibsslTLSVersion::TLS11;
1486 case GNUTLS_TLS1_2:
1487 return LibsslTLSVersion::TLS12;
1488 #if GNUTLS_VERSION_NUMBER >= 0x030603
1489 case GNUTLS_TLS1_3:
1490 return LibsslTLSVersion::TLS13;
1491 #endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */
1492 default:
1493 return LibsslTLSVersion::Unknown;
1494 }
1495 }
1496
1497 bool hasSessionBeenResumed() const override
1498 {
1499 if (d_conn) {
1500 return gnutls_session_is_resumed(d_conn.get()) != 0;
1501 }
1502 return false;
1503 }
1504
1505 std::vector<std::unique_ptr<TLSSession>> getSessions() override
1506 {
1507 return std::move(d_tlsSessions);
1508 }
1509
1510 void setSession(std::unique_ptr<TLSSession>& session) override
1511 {
1512 auto sess = dynamic_cast<GnuTLSSession*>(session.get());
1513 if (!sess) {
1514 throw std::runtime_error("Unable to convert GnuTLS session");
1515 }
1516
1517 auto native = sess->getNative();
1518 auto ret = gnutls_session_set_data(d_conn.get(), native.data, native.size);
1519 if (ret != GNUTLS_E_SUCCESS) {
1520 throw std::runtime_error("Error setting up GnuTLS session: " + std::string(gnutls_strerror(ret)));
1521 }
1522 session.reset();
1523 }
1524
1525 void close() override
1526 {
1527 if (d_conn) {
1528 gnutls_bye(d_conn.get(), GNUTLS_SHUT_RDWR);
1529 }
1530 }
1531
1532 bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos)
1533 {
1534 std::vector<gnutls_datum_t> values;
1535 values.reserve(protos.size());
1536 for (const auto& proto : protos) {
1537 gnutls_datum_t value;
1538 value.data = const_cast<uint8_t*>(proto.data());
1539 value.size = proto.size();
1540 values.push_back(value);
1541 }
1542 unsigned int flags = 0;
1543 #if GNUTLS_VERSION_NUMBER >= 0x030500
1544 flags |= GNUTLS_ALPN_MANDATORY;
1545 #elif defined(GNUTLS_ALPN_MAND)
1546 flags |= GNUTLS_ALPN_MAND;
1547 #endif
1548 return gnutls_alpn_set_protocols(d_conn.get(), values.data(), values.size(), flags);
1549 }
1550
1551 std::vector<int> getAsyncFDs() override
1552 {
1553 return {};
1554 }
1555
1556 private:
1557 std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
1558 std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
1559 std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
1560 std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
1561 std::string d_host;
1562 bool d_client{false};
1563 bool d_handshakeDone{false};
1564 };
1565
1566 class GnuTLSIOCtx: public TLSCtx
1567 {
1568 public:
1569 /* server side context */
1570 GnuTLSIOCtx(TLSFrontend& fe): d_enableTickets(fe.d_tlsConfig.d_enableTickets)
1571 {
1572 int rc = 0;
1573 d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
1574
1575 gnutls_certificate_credentials_t creds;
1576 rc = gnutls_certificate_allocate_credentials(&creds);
1577 if (rc != GNUTLS_E_SUCCESS) {
1578 throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1579 }
1580
1581 d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
1582 creds = nullptr;
1583
1584 for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) {
1585 rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.d_cert.c_str(), pair.d_key->c_str(), GNUTLS_X509_FMT_PEM);
1586 if (rc != GNUTLS_E_SUCCESS) {
1587 throw std::runtime_error("Error loading certificate ('" + pair.d_cert + "') and key ('" + pair.d_key.value() + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1588 }
1589 }
1590
1591 #ifndef DISABLE_OCSP_STAPLING
1592 size_t count = 0;
1593 for (const auto& file : fe.d_tlsConfig.d_ocspFiles) {
1594 rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count);
1595 if (rc != GNUTLS_E_SUCCESS) {
1596 warnlog("Error loading OCSP response from file '%s' for certificate ('%s') and key ('%s') for TLS context on %s: %s", file, fe.d_tlsConfig.d_certKeyPairs.at(count).d_cert, fe.d_tlsConfig.d_certKeyPairs.at(count).d_key.value(), fe.d_addr.toStringWithPort(), gnutls_strerror(rc));
1597 }
1598 ++count;
1599 }
1600 #endif /* DISABLE_OCSP_STAPLING */
1601
1602 #if GNUTLS_VERSION_NUMBER >= 0x030600
1603 rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
1604 if (rc != GNUTLS_E_SUCCESS) {
1605 throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1606 }
1607 #endif
1608
1609 rc = gnutls_priority_init(&d_priorityCache, fe.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : fe.d_tlsConfig.d_ciphers.c_str(), nullptr);
1610 if (rc != GNUTLS_E_SUCCESS) {
1611 throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort());
1612 }
1613
1614 try {
1615 if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
1616 handleTicketsKeyRotation(time(nullptr));
1617 }
1618 else {
1619 GnuTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
1620 }
1621 }
1622 catch(const std::runtime_error& e) {
1623 throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
1624 }
1625 }
1626
1627 /* client side context */
1628 GnuTLSIOCtx(const TLSContextParameters& params): d_contextParameters(std::make_unique<TLSContextParameters>(params)), d_enableTickets(true), d_validateCerts(params.d_validateCertificates)
1629 {
1630 int rc = 0;
1631
1632 gnutls_certificate_credentials_t creds;
1633 rc = gnutls_certificate_allocate_credentials(&creds);
1634 if (rc != GNUTLS_E_SUCCESS) {
1635 throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
1636 }
1637
1638 d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
1639 creds = nullptr;
1640
1641 if (params.d_validateCertificates) {
1642 if (params.d_caStore.empty()) {
1643 #if GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703
1644 /* see https://gitlab.com/gnutls/gnutls/-/issues/1277 */
1645 std::cerr<<"Warning: GnuTLS 3.7.0 - 3.7.2 have a memory leak when validating server certificates in some configurations (PKCS11 support enabled, and a default PKCS11 trust store), please consider upgrading GnuTLS, using the OpenSSL provider for outgoing connections, or explicitly setting a CA store"<<std::endl;
1646 #endif /* GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703 */
1647 rc = gnutls_certificate_set_x509_system_trust(d_creds.get());
1648 if (rc < 0) {
1649 throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
1650 }
1651 }
1652 else {
1653 rc = gnutls_certificate_set_x509_trust_file(d_creds.get(), params.d_caStore.c_str(), GNUTLS_X509_FMT_PEM);
1654 if (rc < 0) {
1655 throw std::runtime_error("Error adding '" + params.d_caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
1656 }
1657 }
1658 }
1659
1660 rc = gnutls_priority_init(&d_priorityCache, params.d_ciphers.empty() ? "NORMAL" : params.d_ciphers.c_str(), nullptr);
1661 if (rc != GNUTLS_E_SUCCESS) {
1662 throw std::runtime_error("Error setting up TLS cipher preferences to 'NORMAL' (" + std::string(gnutls_strerror(rc)) + ")");
1663 }
1664 }
1665
1666 ~GnuTLSIOCtx() override
1667 {
1668 d_creds.reset();
1669
1670 if (d_priorityCache) {
1671 gnutls_priority_deinit(d_priorityCache);
1672 }
1673 }
1674
1675 std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
1676 {
1677 handleTicketsKeyRotation(now);
1678
1679 std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
1680 {
1681 ticketsKey = *(d_ticketsKey.read_lock());
1682 }
1683
1684 auto connection = std::make_unique<GnuTLSConnection>(socket, timeout, d_creds, d_priorityCache, ticketsKey, d_enableTickets);
1685 if (!d_protos.empty()) {
1686 connection->setALPNProtos(d_protos);
1687 }
1688 return connection;
1689 }
1690
1691 static std::shared_ptr<gnutls_certificate_credentials_st> getPerThreadCredentials(bool validate, const std::string& caStore)
1692 {
1693 static thread_local std::map<std::pair<bool, std::string>, std::shared_ptr<gnutls_certificate_credentials_st>> t_credentials;
1694 auto& entry = t_credentials[{validate, caStore}];
1695 if (!entry) {
1696 gnutls_certificate_credentials_t creds;
1697 int rc = gnutls_certificate_allocate_credentials(&creds);
1698 if (rc != GNUTLS_E_SUCCESS) {
1699 throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
1700 }
1701
1702 entry = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
1703 creds = nullptr;
1704
1705 if (validate) {
1706 if (caStore.empty()) {
1707 rc = gnutls_certificate_set_x509_system_trust(entry.get());
1708 if (rc < 0) {
1709 throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
1710 }
1711 }
1712 else {
1713 rc = gnutls_certificate_set_x509_trust_file(entry.get(), caStore.c_str(), GNUTLS_X509_FMT_PEM);
1714 if (rc < 0) {
1715 throw std::runtime_error("Error adding '" + caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
1716 }
1717 }
1718 }
1719 }
1720 return entry;
1721 }
1722
1723 std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool, int socket, const struct timeval& timeout) override
1724 {
1725 auto creds = getPerThreadCredentials(d_contextParameters->d_validateCertificates, d_contextParameters->d_caStore);
1726 auto connection = std::make_unique<GnuTLSConnection>(host, socket, timeout, creds, d_priorityCache, d_validateCerts);
1727 if (!d_protos.empty()) {
1728 connection->setALPNProtos(d_protos);
1729 }
1730 return connection;
1731 }
1732
1733 void rotateTicketsKey(time_t now) override
1734 {
1735 if (!d_enableTickets) {
1736 return;
1737 }
1738
1739 auto newKey = std::make_shared<GnuTLSTicketsKey>();
1740
1741 {
1742 *(d_ticketsKey.write_lock()) = std::move(newKey);
1743 }
1744
1745 if (d_ticketsKeyRotationDelay > 0) {
1746 d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
1747 }
1748 }
1749
1750 void loadTicketsKeys(const std::string& file) final
1751 {
1752 if (!d_enableTickets) {
1753 return;
1754 }
1755
1756 auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
1757 {
1758 *(d_ticketsKey.write_lock()) = std::move(newKey);
1759 }
1760
1761 if (d_ticketsKeyRotationDelay > 0) {
1762 d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
1763 }
1764 }
1765
1766 size_t getTicketsKeysCount() override
1767 {
1768 return *(d_ticketsKey.read_lock()) != nullptr ? 1 : 0;
1769 }
1770
1771 std::string getName() const override
1772 {
1773 return "gnutls";
1774 }
1775
1776 bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos) override
1777 {
1778 #ifdef HAVE_GNUTLS_ALPN_SET_PROTOCOLS
1779 d_protos = protos;
1780 return true;
1781 #else
1782 return false;
1783 #endif
1784 }
1785
1786 private:
1787 /* client context parameters */
1788 std::unique_ptr<TLSContextParameters> d_contextParameters{nullptr};
1789 std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
1790 std::vector<std::vector<uint8_t>> d_protos;
1791 gnutls_priority_t d_priorityCache{nullptr};
1792 SharedLockGuarded<std::shared_ptr<GnuTLSTicketsKey>> d_ticketsKey{nullptr};
1793 bool d_enableTickets{true};
1794 bool d_validateCerts{true};
1795 };
1796
1797 #endif /* HAVE_GNUTLS */
1798
1799 #endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
1800
1801 bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx)
1802 {
1803 if (ctx == nullptr) {
1804 return false;
1805 }
1806 /* we want to set the ALPN to dot (RFC7858), if only to mitigate the ALPACA attack */
1807 const std::vector<std::vector<uint8_t>> dotAlpns = {{'d', 'o', 't'}};
1808 ctx->setALPNProtos(dotAlpns);
1809 return true;
1810 }
1811
1812 bool setupDoHProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx)
1813 {
1814 if (ctx == nullptr) {
1815 return false;
1816 }
1817 /* This code is only called for incoming/server TLS contexts (not outgoing/client),
1818 and h2o sets it own ALPN values.
1819 We want to set the ALPN for DoH:
1820 - HTTP/1.1 so that the OpenSSL callback ALPN accepts it, letting us later return a static response
1821 - HTTP/2
1822 */
1823 const std::vector<std::vector<uint8_t>> dohAlpns{{'h', '2'},{'h', 't', 't', 'p', '/', '1', '.', '1'}};
1824 ctx->setALPNProtos(dohAlpns);
1825
1826 return true;
1827 }
1828
1829 bool TLSFrontend::setupTLS()
1830 {
1831 #if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
1832 std::shared_ptr<TLSCtx> newCtx{nullptr};
1833 /* get the "best" available provider */
1834 #if defined(HAVE_GNUTLS)
1835 if (d_provider == "gnutls") {
1836 newCtx = std::make_shared<GnuTLSIOCtx>(*this);
1837 }
1838 #endif /* HAVE_GNUTLS */
1839 #if defined(HAVE_LIBSSL)
1840 if (d_provider == "openssl") {
1841 newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
1842 }
1843 #endif /* HAVE_LIBSSL */
1844
1845 if (!newCtx) {
1846 #if defined(HAVE_LIBSSL)
1847 newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
1848 #elif defined(HAVE_GNUTLS)
1849 newCtx = std::make_shared<GnuTLSIOCtx>(*this);
1850 #else
1851 #error "TLS support needed but neither libssl nor GnuTLS were selected"
1852 #endif
1853 }
1854
1855 if (d_alpn == ALPN::DoT) {
1856 setupDoTProtocolNegotiation(newCtx);
1857 }
1858 else if (d_alpn == ALPN::DoH) {
1859 setupDoHProtocolNegotiation(newCtx);
1860 }
1861
1862 std::atomic_store_explicit(&d_ctx, std::move(newCtx), std::memory_order_release);
1863 #endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
1864 return true;
1865 }
1866
1867 std::shared_ptr<TLSCtx> getTLSContext([[maybe_unused]] const TLSContextParameters& params)
1868 {
1869 #ifdef HAVE_DNS_OVER_TLS
1870 /* get the "best" available provider */
1871 if (!params.d_provider.empty()) {
1872 #if defined(HAVE_GNUTLS)
1873 if (params.d_provider == "gnutls") {
1874 return std::make_shared<GnuTLSIOCtx>(params);
1875 }
1876 #endif /* HAVE_GNUTLS */
1877 #if defined(HAVE_LIBSSL)
1878 if (params.d_provider == "openssl") {
1879 return std::make_shared<OpenSSLTLSIOCtx>(params);
1880 }
1881 #endif /* HAVE_LIBSSL */
1882 }
1883
1884 #if defined(HAVE_LIBSSL)
1885 return std::make_shared<OpenSSLTLSIOCtx>(params);
1886 #elif defined(HAVE_GNUTLS)
1887 return std::make_shared<GnuTLSIOCtx>(params);
1888 #else
1889 #error "DNS over TLS support needed but neither libssl nor GnuTLS were selected"
1890 #endif
1891
1892 #endif /* HAVE_DNS_OVER_TLS */
1893 return nullptr;
1894 }