]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdistdist/tcpiohandler.cc
dnsdist: Wrap pthread_ objects
[thirdparty/pdns.git] / pdns / dnsdistdist / tcpiohandler.cc
1
2 #include "config.h"
3 #include "dolog.hh"
4 #include "iputils.hh"
5 #include "lock.hh"
6 #include "tcpiohandler.hh"
7
8 #ifdef HAVE_LIBSODIUM
9 #include <sodium.h>
10 #endif /* HAVE_LIBSODIUM */
11
12 #ifdef HAVE_DNS_OVER_TLS
13 #ifdef HAVE_LIBSSL
14 #include <openssl/conf.h>
15 #include <openssl/err.h>
16 #include <openssl/rand.h>
17 #include <openssl/ssl.h>
18
19 #include "libssl.hh"
20
21 class OpenSSLFrontendContext
22 {
23 public:
24 OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys)
25 {
26 registerOpenSSLUser();
27
28 d_tlsCtx = libssl_init_server_context(tlsConfig, d_ocspResponses);
29 if (!d_tlsCtx) {
30 ERR_print_errors_fp(stderr);
31 throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort());
32 }
33 }
34
35 void cleanup()
36 {
37 d_tlsCtx.reset();
38
39 unregisterOpenSSLUser();
40 }
41
42 OpenSSLTLSTicketKeysRing d_ticketKeys;
43 std::map<int, std::string> d_ocspResponses;
44 std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
45 std::unique_ptr<FILE, int(*)(FILE*)> d_keyLogFile{nullptr, fclose};
46 };
47
48 class OpenSSLTLSConnection: public TLSConnection
49 {
50 public:
51 OpenSSLTLSConnection(int socket, unsigned int timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(feContext), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
52 {
53 d_socket = socket;
54
55 if (!s_initTLSConnIndex.test_and_set()) {
56 /* not initialized yet */
57 s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
58 if (s_tlsConnIndex == -1) {
59 throw std::runtime_error("Error getting an index for TLS connection data");
60 }
61 }
62
63 if (!d_conn) {
64 vinfolog("Error creating TLS object");
65 if (g_verbose) {
66 ERR_print_errors_fp(stderr);
67 }
68 throw std::runtime_error("Error creating TLS object");
69 }
70
71 if (!SSL_set_fd(d_conn.get(), d_socket)) {
72 throw std::runtime_error("Error assigning socket");
73 }
74
75 SSL_set_ex_data(d_conn.get(), s_tlsConnIndex, this);
76 }
77
78 IOState convertIORequestToIOState(int res) const
79 {
80 int error = SSL_get_error(d_conn.get(), res);
81 if (error == SSL_ERROR_WANT_READ) {
82 return IOState::NeedRead;
83 }
84 else if (error == SSL_ERROR_WANT_WRITE) {
85 return IOState::NeedWrite;
86 }
87 else if (error == SSL_ERROR_SYSCALL) {
88 throw std::runtime_error("Error while processing TLS connection: " + std::string(strerror(errno)));
89 }
90 else {
91 throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error));
92 }
93 }
94
95 void handleIORequest(int res, unsigned int timeout)
96 {
97 auto state = convertIORequestToIOState(res);
98 if (state == IOState::NeedRead) {
99 res = waitForData(d_socket, timeout);
100 if (res == 0) {
101 throw std::runtime_error("Timeout while reading from TLS connection");
102 }
103 else if (res < 0) {
104 throw std::runtime_error("Error waiting to read from TLS connection");
105 }
106 }
107 else if (state == IOState::NeedWrite) {
108 res = waitForRWData(d_socket, false, timeout, 0);
109 if (res == 0) {
110 throw std::runtime_error("Timeout while writing to TLS connection");
111 }
112 else if (res < 0) {
113 throw std::runtime_error("Error waiting to write to TLS connection");
114 }
115 }
116 }
117
118 IOState tryHandshake() override
119 {
120 int res = SSL_accept(d_conn.get());
121 if (res == 1) {
122 return IOState::Done;
123 }
124 else if (res < 0) {
125 return convertIORequestToIOState(res);
126 }
127
128 throw std::runtime_error("Error accepting TLS connection");
129 }
130
131 void doHandshake() override
132 {
133 int res = 0;
134 do {
135 res = SSL_accept(d_conn.get());
136 if (res < 0) {
137 handleIORequest(res, d_timeout);
138 }
139 }
140 while (res < 0);
141
142 if (res != 1) {
143 throw std::runtime_error("Error accepting TLS connection");
144 }
145 }
146
147 IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
148 {
149 do {
150 int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
151 if (res <= 0) {
152 return convertIORequestToIOState(res);
153 }
154 else {
155 pos += static_cast<size_t>(res);
156 }
157 }
158 while (pos < toWrite);
159 return IOState::Done;
160 }
161
162 IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
163 {
164 do {
165 int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
166 if (res <= 0) {
167 return convertIORequestToIOState(res);
168 }
169 else {
170 pos += static_cast<size_t>(res);
171 }
172 }
173 while (pos < toRead);
174 return IOState::Done;
175 }
176
177 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
178 {
179 size_t got = 0;
180 time_t start = 0;
181 unsigned int remainingTime = totalTimeout;
182 if (totalTimeout) {
183 start = time(nullptr);
184 }
185
186 do {
187 int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
188 if (res <= 0) {
189 handleIORequest(res, readTimeout);
190 }
191 else {
192 got += static_cast<size_t>(res);
193 }
194
195 if (totalTimeout) {
196 time_t now = time(nullptr);
197 unsigned int elapsed = now - start;
198 if (now < start || elapsed >= remainingTime) {
199 throw runtime_error("Timeout while reading data");
200 }
201 start = now;
202 remainingTime -= elapsed;
203 }
204 }
205 while (got < bufferSize);
206
207 return got;
208 }
209
210 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
211 {
212 size_t got = 0;
213 do {
214 int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
215 if (res <= 0) {
216 handleIORequest(res, writeTimeout);
217 }
218 else {
219 got += static_cast<size_t>(res);
220 }
221 }
222 while (got < bufferSize);
223
224 return got;
225 }
226
227 void close() override
228 {
229 if (d_conn) {
230 SSL_shutdown(d_conn.get());
231 }
232 }
233
234 std::string getServerNameIndication() const override
235 {
236 if (d_conn) {
237 const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
238 if (value) {
239 return std::string(value);
240 }
241 }
242 return std::string();
243 }
244
245 LibsslTLSVersion getTLSVersion() const override
246 {
247 auto proto = SSL_version(d_conn.get());
248 switch (proto) {
249 case TLS1_VERSION:
250 return LibsslTLSVersion::TLS10;
251 case TLS1_1_VERSION:
252 return LibsslTLSVersion::TLS11;
253 case TLS1_2_VERSION:
254 return LibsslTLSVersion::TLS12;
255 #ifdef TLS1_3_VERSION
256 case TLS1_3_VERSION:
257 return LibsslTLSVersion::TLS13;
258 #endif /* TLS1_3_VERSION */
259 default:
260 return LibsslTLSVersion::Unknown;
261 }
262 }
263
264 bool hasSessionBeenResumed() const override
265 {
266 if (d_conn) {
267 return SSL_session_reused(d_conn.get()) != 0;
268 }
269 return false;
270 }
271
272 static int s_tlsConnIndex;
273
274 private:
275 static std::atomic_flag s_initTLSConnIndex;
276
277 std::shared_ptr<OpenSSLFrontendContext> d_feContext;
278 std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
279 unsigned int d_timeout;
280 };
281
282 std::atomic_flag OpenSSLTLSConnection::s_initTLSConnIndex = ATOMIC_FLAG_INIT;
283 int OpenSSLTLSConnection::s_tlsConnIndex = -1;
284
285 class OpenSSLTLSIOCtx: public TLSCtx
286 {
287 public:
288 OpenSSLTLSIOCtx(TLSFrontend& fe)
289 {
290 d_feContext = std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig);
291
292 d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
293
294 if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) {
295 /* use our own ticket keys handler so we can rotate them */
296 SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
297 libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get());
298 }
299
300 if (!d_feContext->d_ocspResponses.empty()) {
301 SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
302 SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses);
303 }
304
305 libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters);
306
307 if (!fe.d_tlsConfig.d_keyLogFile.empty()) {
308 d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile);
309 }
310
311 try {
312 if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
313 handleTicketsKeyRotation(time(nullptr));
314 }
315 else {
316 OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
317 }
318 }
319 catch (const std::exception& e) {
320 throw;
321 }
322 }
323
324 ~OpenSSLTLSIOCtx() override
325 {
326 }
327
328 static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
329 {
330 OpenSSLFrontendContext* ctx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(s));
331 if (ctx == nullptr) {
332 return -1;
333 }
334
335 int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc);
336 if (enc == 0) {
337 if (ret == 0 || ret == 2) {
338 OpenSSLTLSConnection* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(s, OpenSSLTLSConnection::s_tlsConnIndex));
339 if (conn) {
340 if (ret == 0) {
341 conn->setUnknownTicketKey();
342 }
343 else if (ret == 2) {
344 conn->setResumedFromInactiveTicketKey();
345 }
346 }
347 }
348 }
349
350 return ret;
351 }
352
353 static int ocspStaplingCb(SSL* ssl, void* arg)
354 {
355 if (ssl == nullptr || arg == nullptr) {
356 return SSL_TLSEXT_ERR_NOACK;
357 }
358 const auto ocspMap = reinterpret_cast<std::map<int, std::string>*>(arg);
359 return libssl_ocsp_stapling_callback(ssl, *ocspMap);
360 }
361
362 std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
363 {
364 handleTicketsKeyRotation(now);
365
366 return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_feContext));
367 }
368
369 void rotateTicketsKey(time_t now) override
370 {
371 d_feContext->d_ticketKeys.rotateTicketsKey(now);
372
373 if (d_ticketsKeyRotationDelay > 0) {
374 d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
375 }
376 }
377
378 void loadTicketsKeys(const std::string& keyFile) override final
379 {
380 d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
381
382 if (d_ticketsKeyRotationDelay > 0) {
383 d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
384 }
385 }
386
387 size_t getTicketsKeysCount() override
388 {
389 return d_feContext->d_ticketKeys.getKeysCount();
390 }
391
392 private:
393 std::shared_ptr<OpenSSLFrontendContext> d_feContext;
394 };
395
396 #endif /* HAVE_LIBSSL */
397
398 #ifdef HAVE_GNUTLS
399 #include <gnutls/gnutls.h>
400 #include <gnutls/x509.h>
401
/* Hands the memory region over to sodium_mlock() when libsodium is
   available; does nothing at all otherwise. */
static void safe_memory_lock(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_mlock(data, size);
#endif /* HAVE_LIBSODIUM */
}
408
/* Counterpart of safe_memory_lock(): passes the region to sodium_munlock()
   when libsodium is available, and otherwise zeroes it with the first
   available primitive among explicit_bzero(), explicit_memset() and
   gnutls_memset(), falling back to a hand-rolled loop. */
static void safe_memory_release(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
  if (size == 0) {
    return;
  }

  /* the buffer is read back through volatile objects after each memset() */
  volatile unsigned int probeIdx = 0;
  volatile unsigned char* bytes = reinterpret_cast<volatile unsigned char*>(data);

  for (;;) {
    memset(data, 0, size);
    if (bytes[probeIdx] == 0) {
      break;
    }
  }
#endif
}
432
433 class GnuTLSTicketsKey
434 {
435 public:
436 GnuTLSTicketsKey()
437 {
438 if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
439 throw std::runtime_error("Error generating tickets key for TLS context");
440 }
441
442 safe_memory_lock(d_key.data, d_key.size);
443 }
444
445 GnuTLSTicketsKey(const std::string& keyFile)
446 {
447 /* to be sure we are loading the correct amount of data, which
448 may change between versions, let's generate a correct key first */
449 if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
450 throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
451 }
452
453 safe_memory_lock(d_key.data, d_key.size);
454
455 try {
456 ifstream file(keyFile);
457 file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
458
459 if (file.fail()) {
460 file.close();
461 throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
462 }
463
464 file.close();
465 }
466 catch (const std::exception& e) {
467 safe_memory_release(d_key.data, d_key.size);
468 gnutls_free(d_key.data);
469 d_key.data = nullptr;
470 throw;
471 }
472 }
473
474 ~GnuTLSTicketsKey()
475 {
476 if (d_key.data != nullptr && d_key.size > 0) {
477 safe_memory_release(d_key.data, d_key.size);
478 }
479 gnutls_free(d_key.data);
480 d_key.data = nullptr;
481 }
482 const gnutls_datum_t& getKey() const
483 {
484 return d_key;
485 }
486
487 private:
488 gnutls_datum_t d_key{nullptr, 0};
489 };
490
491 class GnuTLSConnection: public TLSConnection
492 {
493 public:
494
495 GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
496 {
497 unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
498 #ifdef GNUTLS_NO_SIGNAL
499 sslOptions |= GNUTLS_NO_SIGNAL;
500 #endif
501
502 d_socket = socket;
503
504 gnutls_session_t conn;
505 if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
506 throw std::runtime_error("Error creating TLS connection");
507 }
508
509 d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
510 conn = nullptr;
511
512 if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
513 throw std::runtime_error("Error setting certificate and key to TLS connection");
514 }
515
516 if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
517 throw std::runtime_error("Error setting ciphers to TLS connection");
518 }
519
520 if (enableTickets && d_ticketsKey) {
521 const gnutls_datum_t& key = d_ticketsKey->getKey();
522 if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
523 throw std::runtime_error("Error setting the tickets key to TLS connection");
524 }
525 }
526
527 gnutls_transport_set_int(d_conn.get(), d_socket);
528
529 /* timeouts are in milliseconds */
530 gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
531 gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
532 }
533
534 void doHandshake() override
535 {
536 int ret = 0;
537 do {
538 ret = gnutls_handshake(d_conn.get());
539 if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
540 throw std::runtime_error("Error accepting a new connection");
541 }
542 }
543 while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
544 }
545
546 IOState tryHandshake() override
547 {
548 int ret = 0;
549
550 do {
551 ret = gnutls_handshake(d_conn.get());
552 if (ret == GNUTLS_E_SUCCESS) {
553 return IOState::Done;
554 }
555 else if (ret == GNUTLS_E_AGAIN) {
556 return IOState::NeedRead;
557 }
558 else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
559 throw std::runtime_error("Error accepting a new connection");
560 }
561 } while (ret == GNUTLS_E_INTERRUPTED);
562
563 throw std::runtime_error("Error accepting a new connection");
564 }
565
566 IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
567 {
568 do {
569 ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
570 if (res == 0) {
571 throw std::runtime_error("Error writing to TLS connection");
572 }
573 else if (res > 0) {
574 pos += static_cast<size_t>(res);
575 }
576 else if (res < 0) {
577 if (gnutls_error_is_fatal(res)) {
578 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
579 }
580 else if (res == GNUTLS_E_AGAIN) {
581 return IOState::NeedWrite;
582 }
583 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
584 }
585 }
586 while (pos < toWrite);
587 return IOState::Done;
588 }
589
590 IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
591 {
592 do {
593 ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
594 if (res == 0) {
595 throw std::runtime_error("Error reading from TLS connection");
596 }
597 else if (res > 0) {
598 pos += static_cast<size_t>(res);
599 }
600 else if (res < 0) {
601 if (gnutls_error_is_fatal(res)) {
602 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
603 }
604 else if (res == GNUTLS_E_AGAIN) {
605 return IOState::NeedRead;
606 }
607 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
608 }
609 }
610 while (pos < toRead);
611 return IOState::Done;
612 }
613
614 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
615 {
616 size_t got = 0;
617 time_t start = 0;
618 unsigned int remainingTime = totalTimeout;
619 if (totalTimeout) {
620 start = time(nullptr);
621 }
622
623 do {
624 ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
625 if (res == 0) {
626 throw std::runtime_error("Error reading from TLS connection");
627 }
628 else if (res > 0) {
629 got += static_cast<size_t>(res);
630 }
631 else if (res < 0) {
632 if (gnutls_error_is_fatal(res)) {
633 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
634 }
635 else if (res == GNUTLS_E_AGAIN) {
636 int result = waitForData(d_socket, readTimeout);
637 if (result <= 0) {
638 throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result));
639 }
640 }
641 else {
642 vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
643 }
644 }
645
646 if (totalTimeout) {
647 time_t now = time(nullptr);
648 unsigned int elapsed = now - start;
649 if (now < start || elapsed >= remainingTime) {
650 throw runtime_error("Timeout while reading data");
651 }
652 start = now;
653 remainingTime -= elapsed;
654 }
655 }
656 while (got < bufferSize);
657
658 return got;
659 }
660
661 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
662 {
663 size_t got = 0;
664
665 do {
666 ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
667 if (res == 0) {
668 throw std::runtime_error("Error writing to TLS connection");
669 }
670 else if (res > 0) {
671 got += static_cast<size_t>(res);
672 }
673 else if (res < 0) {
674 if (gnutls_error_is_fatal(res)) {
675 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
676 }
677 else if (res == GNUTLS_E_AGAIN) {
678 int result = waitForRWData(d_socket, false, writeTimeout, 0);
679 if (result <= 0) {
680 throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
681 }
682 }
683 else {
684 vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
685 }
686 }
687 }
688 while (got < bufferSize);
689
690 return got;
691 }
692
693 std::string getServerNameIndication() const override
694 {
695 if (d_conn) {
696 unsigned int type;
697 size_t name_len = 256;
698 std::string sni;
699 sni.resize(name_len);
700
701 int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0);
702 if (res == GNUTLS_E_SUCCESS) {
703 sni.resize(name_len);
704 return sni;
705 }
706 }
707 return std::string();
708 }
709
710 LibsslTLSVersion getTLSVersion() const override
711 {
712 auto proto = gnutls_protocol_get_version(d_conn.get());
713 switch (proto) {
714 case GNUTLS_TLS1_0:
715 return LibsslTLSVersion::TLS10;
716 case GNUTLS_TLS1_1:
717 return LibsslTLSVersion::TLS11;
718 case GNUTLS_TLS1_2:
719 return LibsslTLSVersion::TLS12;
720 #if GNUTLS_VERSION_NUMBER >= 0x030603
721 case GNUTLS_TLS1_3:
722 return LibsslTLSVersion::TLS13;
723 #endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */
724 default:
725 return LibsslTLSVersion::Unknown;
726 }
727 }
728
729 bool hasSessionBeenResumed() const override
730 {
731 if (d_conn) {
732 return gnutls_session_is_resumed(d_conn.get()) != 0;
733 }
734 return false;
735 }
736
737 void close() override
738 {
739 if (d_conn) {
740 gnutls_bye(d_conn.get(), GNUTLS_SHUT_WR);
741 }
742 }
743
744 private:
745 std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
746 std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
747 };
748
749 class GnuTLSIOCtx: public TLSCtx
750 {
751 public:
752 GnuTLSIOCtx(TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_tlsConfig.d_enableTickets)
753 {
754 int rc = 0;
755 d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
756
757 gnutls_certificate_credentials_t creds;
758 rc = gnutls_certificate_allocate_credentials(&creds);
759 if (rc != GNUTLS_E_SUCCESS) {
760 throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
761 }
762
763 d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
764 creds = nullptr;
765
766 for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) {
767 rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM);
768 if (rc != GNUTLS_E_SUCCESS) {
769 throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
770 }
771 }
772
773 size_t count = 0;
774 for (const auto& file : fe.d_tlsConfig.d_ocspFiles) {
775 rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count);
776 if (rc != GNUTLS_E_SUCCESS) {
777 throw std::runtime_error("Error loading OCSP response from file '" + file + "' for certificate ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).first + "') and key ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
778 }
779 ++count;
780 }
781
782 #if GNUTLS_VERSION_NUMBER >= 0x030600
783 rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
784 if (rc != GNUTLS_E_SUCCESS) {
785 throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
786 }
787 #endif
788
789 rc = gnutls_priority_init(&d_priorityCache, fe.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : fe.d_tlsConfig.d_ciphers.c_str(), nullptr);
790 if (rc != GNUTLS_E_SUCCESS) {
791 throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort());
792 }
793
794 try {
795 if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
796 handleTicketsKeyRotation(time(nullptr));
797 }
798 else {
799 GnuTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
800 }
801 }
802 catch(const std::runtime_error& e) {
803 throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
804 }
805 }
806
807 virtual ~GnuTLSIOCtx() override
808 {
809 d_creds.reset();
810
811 if (d_priorityCache) {
812 gnutls_priority_deinit(d_priorityCache);
813 }
814 }
815
816 std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
817 {
818 handleTicketsKeyRotation(now);
819
820 std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
821 {
822 ReadLock rl(&d_lock);
823 ticketsKey = d_ticketsKey;
824 }
825
826 return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets));
827 }
828
829 void rotateTicketsKey(time_t now) override
830 {
831 if (!d_enableTickets) {
832 return;
833 }
834
835 auto newKey = std::make_shared<GnuTLSTicketsKey>();
836
837 {
838 WriteLock wl(&d_lock);
839 d_ticketsKey = newKey;
840 }
841
842 if (d_ticketsKeyRotationDelay > 0) {
843 d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
844 }
845 }
846
847 void loadTicketsKeys(const std::string& file) override final
848 {
849 if (!d_enableTickets) {
850 return;
851 }
852
853 auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
854 {
855 WriteLock wl(&d_lock);
856 d_ticketsKey = newKey;
857 }
858
859 if (d_ticketsKeyRotationDelay > 0) {
860 d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
861 }
862 }
863
864 size_t getTicketsKeysCount() override
865 {
866 ReadLock rl(&d_lock);
867 return d_ticketsKey != nullptr ? 1 : 0;
868 }
869
870 private:
871 std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
872 gnutls_priority_t d_priorityCache{nullptr};
873 std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr};
874 ReadWriteLock d_lock;
875 bool d_enableTickets{true};
876 };
877
878 #endif /* HAVE_GNUTLS */
879
880 #endif /* HAVE_DNS_OVER_TLS */
881
882 bool TLSFrontend::setupTLS()
883 {
884 #ifdef HAVE_DNS_OVER_TLS
885 /* get the "best" available provider */
886 if (!d_provider.empty()) {
887 #ifdef HAVE_GNUTLS
888 if (d_provider == "gnutls") {
889 d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
890 return true;
891 }
892 #endif /* HAVE_GNUTLS */
893 #ifdef HAVE_LIBSSL
894 if (d_provider == "openssl") {
895 d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
896 return true;
897 }
898 #endif /* HAVE_LIBSSL */
899 }
900 #ifdef HAVE_LIBSSL
901 d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
902 #else /* HAVE_LIBSSL */
903 #ifdef HAVE_GNUTLS
904 d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
905 #endif /* HAVE_GNUTLS */
906 #endif /* HAVE_LIBSSL */
907
908 #endif /* HAVE_DNS_OVER_TLS */
909 return true;
910 }