]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdistdist/tcpiohandler.cc
dnsdist: Add missing overrides
[thirdparty/pdns.git] / pdns / dnsdistdist / tcpiohandler.cc
1 #include <fstream>
2
3 #include "config.h"
4 #include "dolog.hh"
5 #include "iputils.hh"
6 #include "lock.hh"
7 #include "tcpiohandler.hh"
8
9 #ifdef HAVE_LIBSODIUM
10 #include <sodium.h>
11 #endif /* HAVE_LIBSODIUM */
12
13 #ifdef HAVE_DNS_OVER_TLS
14 #ifdef HAVE_LIBSSL
15 #include <openssl/conf.h>
16 #include <openssl/err.h>
17 #include <openssl/rand.h>
18 #include <openssl/ssl.h>
19
20 #include <boost/circular_buffer.hpp>
21
22 #if (OPENSSL_VERSION_NUMBER < 0x1010000fL || defined LIBRESSL_VERSION_NUMBER)
23 /* OpenSSL < 1.1.0 needs support for threading/locking in the calling application. */
24 static pthread_mutex_t *openssllocks{nullptr};
25
26 extern "C" {
27 static void openssl_pthreads_locking_callback(int mode, int type, const char *file, int line)
28 {
29 if (mode & CRYPTO_LOCK) {
30 pthread_mutex_lock(&(openssllocks[type]));
31
32 } else {
33 pthread_mutex_unlock(&(openssllocks[type]));
34 }
35 }
36
37 static unsigned long openssl_pthreads_id_callback()
38 {
39 return (unsigned long)pthread_self();
40 }
41 }
42
43 static void openssl_thread_setup()
44 {
45 openssllocks = (pthread_mutex_t*)OPENSSL_malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
46
47 for (int i = 0; i < CRYPTO_num_locks(); i++)
48 pthread_mutex_init(&(openssllocks[i]), NULL);
49
50 CRYPTO_set_id_callback(openssl_pthreads_id_callback);
51 CRYPTO_set_locking_callback(openssl_pthreads_locking_callback);
52 }
53
54 static void openssl_thread_cleanup()
55 {
56 CRYPTO_set_locking_callback(NULL);
57
58 for (int i=0; i<CRYPTO_num_locks(); i++) {
59 pthread_mutex_destroy(&(openssllocks[i]));
60 }
61
62 OPENSSL_free(openssllocks);
63 }
64
65 #else
/* OpenSSL >= 1.1.0 takes care of its own locking: nothing to install... */
static void openssl_thread_setup() {}

/* ...and therefore nothing to tear down either. */
static void openssl_thread_cleanup() {}
73 #endif /* (OPENSSL_VERSION_NUMBER < 0x1010000fL || defined LIBRESSL_VERSION_NUMBER) */
74
75 /* From rfc5077 Section 4. Recommended Ticket Construction */
76 #define TLS_TICKETS_KEY_NAME_SIZE (16)
77
78 /* AES-256 */
79 #define TLS_TICKETS_CIPHER_KEY_SIZE (32)
80 #define TLS_TICKETS_CIPHER_ALGO (EVP_aes_256_cbc)
81
82 /* HMAC SHA-256 */
83 #define TLS_TICKETS_MAC_KEY_SIZE (32)
84 #define TLS_TICKETS_MAC_ALGO (EVP_sha256)
85
86 static int s_ticketsKeyIndex{-1};
87
88 class OpenSSLTLSTicketKey
89 {
90 public:
91 OpenSSLTLSTicketKey()
92 {
93 if (RAND_bytes(d_name, sizeof(d_name)) != 1) {
94 throw std::runtime_error("Error while generating the name of the OpenSSL TLS ticket key");
95 }
96
97 if (RAND_bytes(d_cipherKey, sizeof(d_cipherKey)) != 1) {
98 throw std::runtime_error("Error while generating the cipher key of the OpenSSL TLS ticket key");
99 }
100
101 if (RAND_bytes(d_hmacKey, sizeof(d_hmacKey)) != 1) {
102 throw std::runtime_error("Error while generating the HMAC key of the OpenSSL TLS ticket key");
103 }
104 #ifdef HAVE_LIBSODIUM
105 sodium_mlock(d_name, sizeof(d_name));
106 sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
107 sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
108 #endif /* HAVE_LIBSODIUM */
109 }
110
111 OpenSSLTLSTicketKey(ifstream& file)
112 {
113 file.read(reinterpret_cast<char*>(d_name), sizeof(d_name));
114 file.read(reinterpret_cast<char*>(d_cipherKey), sizeof(d_cipherKey));
115 file.read(reinterpret_cast<char*>(d_hmacKey), sizeof(d_hmacKey));
116
117 if (file.fail()) {
118 throw std::runtime_error("Unable to load a ticket key from the OpenSSL tickets key file");
119 }
120 #ifdef HAVE_LIBSODIUM
121 sodium_mlock(d_name, sizeof(d_name));
122 sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
123 sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
124 #endif /* HAVE_LIBSODIUM */
125 }
126
127 ~OpenSSLTLSTicketKey()
128 {
129 #ifdef HAVE_LIBSODIUM
130 sodium_munlock(d_name, sizeof(d_name));
131 sodium_munlock(d_cipherKey, sizeof(d_cipherKey));
132 sodium_munlock(d_hmacKey, sizeof(d_hmacKey));
133 #else
134 OPENSSL_cleanse(d_name, sizeof(d_name));
135 OPENSSL_cleanse(d_cipherKey, sizeof(d_cipherKey));
136 OPENSSL_cleanse(d_hmacKey, sizeof(d_hmacKey));
137 #endif /* HAVE_LIBSODIUM */
138 }
139
140 bool nameMatches(const unsigned char name[TLS_TICKETS_KEY_NAME_SIZE]) const
141 {
142 return (memcmp(d_name, name, sizeof(d_name)) == 0);
143 }
144
145 int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const
146 {
147 memcpy(keyName, d_name, sizeof(d_name));
148
149 if (RAND_bytes(iv, EVP_MAX_IV_LENGTH) != 1) {
150 return -1;
151 }
152
153 if (EVP_EncryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) {
154 return -1;
155 }
156
157 if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
158 return -1;
159 }
160
161 return 1;
162 }
163
164 bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const
165 {
166 if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
167 return false;
168 }
169
170 if (EVP_DecryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) {
171 return false;
172 }
173
174 return true;
175 }
176
177 private:
178 unsigned char d_name[TLS_TICKETS_KEY_NAME_SIZE];
179 unsigned char d_cipherKey[TLS_TICKETS_CIPHER_KEY_SIZE];
180 unsigned char d_hmacKey[TLS_TICKETS_MAC_KEY_SIZE];
181 };
182
183 class OpenSSLTLSTicketKeysRing
184 {
185 public:
186 OpenSSLTLSTicketKeysRing(size_t capacity)
187 {
188 pthread_rwlock_init(&d_lock, nullptr);
189 d_ticketKeys.set_capacity(capacity);
190 }
191
192 ~OpenSSLTLSTicketKeysRing()
193 {
194 pthread_rwlock_destroy(&d_lock);
195 }
196
197 void addKey(std::shared_ptr<OpenSSLTLSTicketKey> newKey)
198 {
199 WriteLock wl(&d_lock);
200 d_ticketKeys.push_back(newKey);
201 }
202
203 std::shared_ptr<OpenSSLTLSTicketKey> getEncryptionKey()
204 {
205 ReadLock rl(&d_lock);
206 return d_ticketKeys.front();
207 }
208
209 std::shared_ptr<OpenSSLTLSTicketKey> getDecryptionKey(unsigned char name[TLS_TICKETS_KEY_NAME_SIZE], bool& activeKey)
210 {
211 ReadLock rl(&d_lock);
212 for (auto& key : d_ticketKeys) {
213 if (key->nameMatches(name)) {
214 activeKey = (key == d_ticketKeys.front());
215 return key;
216 }
217 }
218 return nullptr;
219 }
220
221 size_t getKeysCount()
222 {
223 ReadLock rl(&d_lock);
224 return d_ticketKeys.size();
225 }
226
227 private:
228 boost::circular_buffer<std::shared_ptr<OpenSSLTLSTicketKey> > d_ticketKeys;
229 pthread_rwlock_t d_lock;
230 };
231
232 class OpenSSLTLSConnection: public TLSConnection
233 {
234 public:
235 OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free)), d_timeout(timeout)
236 {
237 d_socket = socket;
238
239 if (!d_conn) {
240 vinfolog("Error creating TLS object");
241 if (g_verbose) {
242 ERR_print_errors_fp(stderr);
243 }
244 throw std::runtime_error("Error creating TLS object");
245 }
246
247 if (!SSL_set_fd(d_conn.get(), d_socket)) {
248 throw std::runtime_error("Error assigning socket");
249 }
250 }
251
252 IOState convertIORequestToIOState(int res) const
253 {
254 int error = SSL_get_error(d_conn.get(), res);
255 if (error == SSL_ERROR_WANT_READ) {
256 return IOState::NeedRead;
257 }
258 else if (error == SSL_ERROR_WANT_WRITE) {
259 return IOState::NeedWrite;
260 }
261 else if (error == SSL_ERROR_SYSCALL) {
262 throw std::runtime_error("Error while processing TLS connection:" + std::string(strerror(errno)));
263 }
264 else {
265 throw std::runtime_error("Error while processing TLS connection:" + std::to_string(error));
266 }
267 }
268
269 void handleIORequest(int res, unsigned int timeout)
270 {
271 auto state = convertIORequestToIOState(res);
272 if (state == IOState::NeedRead) {
273 res = waitForData(d_socket, timeout);
274 if (res <= 0) {
275 throw std::runtime_error("Error reading from TLS connection");
276 }
277 }
278 else if (state == IOState::NeedWrite) {
279 res = waitForRWData(d_socket, false, timeout, 0);
280 if (res <= 0) {
281 throw std::runtime_error("Error waiting to write to TLS connection");
282 }
283 }
284 }
285
286 IOState tryHandshake() override
287 {
288 int res = SSL_accept(d_conn.get());
289 if (res == 1) {
290 return IOState::Done;
291 }
292 else if (res < 0) {
293 return convertIORequestToIOState(res);
294 }
295
296 throw std::runtime_error("Error accepting TLS connection");
297 }
298
299 void doHandshake() override
300 {
301 int res = 0;
302 do {
303 res = SSL_accept(d_conn.get());
304 if (res < 0) {
305 handleIORequest(res, d_timeout);
306 }
307 }
308 while (res < 0);
309
310 if (res != 1) {
311 throw std::runtime_error("Error accepting TLS connection");
312 }
313 }
314
315 IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
316 {
317 do {
318 int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
319 if (res == 0) {
320 throw std::runtime_error("Error writing to TLS connection");
321 }
322 else if (res < 0) {
323 return convertIORequestToIOState(res);
324 }
325 else {
326 pos += static_cast<size_t>(res);
327 }
328 }
329 while (pos < toWrite);
330 return IOState::Done;
331 }
332
333 IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
334 {
335 do {
336 int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
337 if (res == 0) {
338 throw std::runtime_error("Error reading from TLS connection");
339 }
340 else if (res < 0) {
341 return convertIORequestToIOState(res);
342 }
343 else {
344 pos += static_cast<size_t>(res);
345 }
346 }
347 while (pos < toRead);
348 return IOState::Done;
349 }
350
351 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
352 {
353 size_t got = 0;
354 time_t start = 0;
355 unsigned int remainingTime = totalTimeout;
356 if (totalTimeout) {
357 start = time(nullptr);
358 }
359
360 do {
361 int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
362 if (res == 0) {
363 throw std::runtime_error("Error reading from TLS connection");
364 }
365 else if (res < 0) {
366 handleIORequest(res, readTimeout);
367 }
368 else {
369 got += static_cast<size_t>(res);
370 }
371
372 if (totalTimeout) {
373 time_t now = time(nullptr);
374 unsigned int elapsed = now - start;
375 if (now < start || elapsed >= remainingTime) {
376 throw runtime_error("Timeout while reading data");
377 }
378 start = now;
379 remainingTime -= elapsed;
380 }
381 }
382 while (got < bufferSize);
383
384 return got;
385 }
386
387 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
388 {
389 size_t got = 0;
390 do {
391 int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
392 if (res == 0) {
393 throw std::runtime_error("Error writing to TLS connection");
394 }
395 else if (res < 0) {
396 handleIORequest(res, writeTimeout);
397 }
398 else {
399 got += static_cast<size_t>(res);
400 }
401 }
402 while (got < bufferSize);
403
404 return got;
405 }
406 void close() override
407 {
408 if (d_conn) {
409 SSL_shutdown(d_conn.get());
410 }
411 }
412
413 private:
414 std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
415 unsigned int d_timeout;
416 };
417
418 class OpenSSLTLSIOCtx: public TLSCtx
419 {
420 public:
421 OpenSSLTLSIOCtx(const TLSFrontend& fe): d_ticketKeys(fe.d_numberOfTicketsKeys), d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
422 {
423 d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay;
424
425 int sslOptions =
426 SSL_OP_NO_SSLv2 |
427 SSL_OP_NO_SSLv3 |
428 SSL_OP_NO_COMPRESSION |
429 SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
430 SSL_OP_SINGLE_DH_USE |
431 SSL_OP_SINGLE_ECDH_USE |
432 SSL_OP_CIPHER_SERVER_PREFERENCE;
433
434 if (!fe.d_enableTickets) {
435 sslOptions |= SSL_OP_NO_TICKET;
436 }
437
438 if (s_users.fetch_add(1) == 0) {
439 ERR_load_crypto_strings();
440 OpenSSL_add_ssl_algorithms();
441 openssl_thread_setup();
442
443 s_ticketsKeyIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
444
445 if (s_ticketsKeyIndex == -1) {
446 throw std::runtime_error("Error getting an index for tickets key");
447 }
448 }
449
450 d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free);
451 if (!d_tlsCtx) {
452 ERR_print_errors_fp(stderr);
453 throw std::runtime_error("Error creating TLS context on " + fe.d_addr.toStringWithPort());
454 }
455
456 /* use our own ticket keys handler so we can rotate them */
457 SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
458 SSL_CTX_set_ex_data(d_tlsCtx.get(), s_ticketsKeyIndex, this);
459 SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
460 #if defined(SSL_CTX_set_ecdh_auto)
461 SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
462 #endif
463 if (fe.d_maxStoredSessions == 0) {
464 /* disable stored sessions entirely */
465 SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_OFF);
466 }
467 else {
468 /* use the internal built-in cache to store sessions */
469 SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_SERVER);
470 SSL_CTX_sess_set_cache_size(d_tlsCtx.get(), fe.d_maxStoredSessions);
471 }
472
473 for (const auto& pair : fe.d_certKeyPairs) {
474 if (SSL_CTX_use_certificate_chain_file(d_tlsCtx.get(), pair.first.c_str()) != 1) {
475 ERR_print_errors_fp(stderr);
476 throw std::runtime_error("Error loading certificate from " + pair.first + " for the TLS context on " + fe.d_addr.toStringWithPort());
477 }
478 if (SSL_CTX_use_PrivateKey_file(d_tlsCtx.get(), pair.second.c_str(), SSL_FILETYPE_PEM) != 1) {
479 ERR_print_errors_fp(stderr);
480 throw std::runtime_error("Error loading key from " + pair.second + " for the TLS context on " + fe.d_addr.toStringWithPort());
481 }
482 }
483
484 if (!fe.d_ciphers.empty()) {
485 if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), fe.d_ciphers.c_str()) != 1) {
486 ERR_print_errors_fp(stderr);
487 throw std::runtime_error("Error setting the cipher list to '" + fe.d_ciphers + "' for the TLS context on " + fe.d_addr.toStringWithPort());
488 }
489 }
490
491 try {
492 if (fe.d_ticketKeyFile.empty()) {
493 handleTicketsKeyRotation(time(nullptr));
494 }
495 else {
496 loadTicketsKeys(fe.d_ticketKeyFile);
497 }
498 }
499 catch (const std::exception& e) {
500 throw;
501 }
502 }
503
504 virtual ~OpenSSLTLSIOCtx() override
505 {
506 d_tlsCtx.reset();
507
508 if (s_users.fetch_sub(1) == 1) {
509 ERR_free_strings();
510
511 EVP_cleanup();
512
513 CONF_modules_finish();
514 CONF_modules_free();
515 CONF_modules_unload(1);
516
517 CRYPTO_cleanup_all_ex_data();
518 openssl_thread_cleanup();
519 }
520 }
521
522 static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
523 {
524 SSL_CTX* sslCtx = SSL_get_SSL_CTX(s);
525 if (sslCtx == nullptr) {
526 return -1;
527 }
528
529 OpenSSLTLSIOCtx* ctx = reinterpret_cast<OpenSSLTLSIOCtx*>(SSL_CTX_get_ex_data(sslCtx, s_ticketsKeyIndex));
530 if (ctx == nullptr) {
531 return -1;
532 }
533
534 if (enc) {
535 const auto key = ctx->d_ticketKeys.getEncryptionKey();
536 if (key == nullptr) {
537 return -1;
538 }
539
540 return key->encrypt(keyName, iv, ectx, hctx);
541 }
542
543 bool activeEncryptionKey = false;
544
545 const auto key = ctx->d_ticketKeys.getDecryptionKey(keyName, activeEncryptionKey);
546 if (key == nullptr) {
547 /* we don't know this key, just create a new ticket */
548 return 0;
549 }
550
551 if (key->decrypt(iv, ectx, hctx) == false) {
552 return -1;
553 }
554
555 if (!activeEncryptionKey) {
556 /* this key is not active, please encrypt the ticket content with the currently active one */
557 return 2;
558 }
559
560 return 1;
561 }
562
563 std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
564 {
565 handleTicketsKeyRotation(now);
566
567 return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_tlsCtx.get()));
568 }
569
570 void rotateTicketsKey(time_t now) override
571 {
572 auto newKey = std::make_shared<OpenSSLTLSTicketKey>();
573 d_ticketKeys.addKey(newKey);
574
575 if (d_ticketsKeyRotationDelay > 0) {
576 d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
577 }
578 }
579
580 void loadTicketsKeys(const std::string& keyFile) override
581 {
582 bool keyLoaded = false;
583 ifstream file(keyFile);
584 try {
585 do {
586 auto newKey = std::make_shared<OpenSSLTLSTicketKey>(file);
587 d_ticketKeys.addKey(newKey);
588 keyLoaded = true;
589 }
590 while (!file.fail());
591 }
592 catch (const std::exception& e) {
593 /* if we haven't been able to load at least one key, fail */
594 if (!keyLoaded) {
595 throw;
596 }
597 }
598
599 if (d_ticketsKeyRotationDelay > 0) {
600 d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
601 }
602
603 file.close();
604 }
605
606 size_t getTicketsKeysCount() override
607 {
608 return d_ticketKeys.getKeysCount();
609 }
610
611 private:
612 OpenSSLTLSTicketKeysRing d_ticketKeys;
613 std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx;
614 static std::atomic<uint64_t> s_users;
615 };
616
617 std::atomic<uint64_t> OpenSSLTLSIOCtx::s_users(0);
618
619 #endif /* HAVE_LIBSSL */
620
621 #ifdef HAVE_GNUTLS
622 #include <gnutls/gnutls.h>
623 #include <gnutls/x509.h>
624
/* Best effort: ask libsodium to keep the 'size' bytes at 'data' locked in memory.
   Without libsodium nothing is done. */
void safe_memory_lock(void* data, size_t size)
{
#ifndef HAVE_LIBSODIUM
  (void) data;
  (void) size;
#else
  sodium_mlock(data, size);
#endif
}
631
/* Zero the 'size' bytes at 'data' in a way the compiler may not elide, using the
   best primitive available (libsodium's variant also unlocks the memory). */
void safe_memory_release(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  /* portable fallback, shamelessly taken from Dovecot's src/lib/safe-memset.c:
     reading a byte back through a volatile pointer, at a volatile index, keeps the
     compiler from dropping the memset() as a dead store */
  if (size == 0) {
    return;
  }

  volatile unsigned int zeroIdx = 0;
  volatile unsigned char* bytes = reinterpret_cast<volatile unsigned char*>(data);

  for (;;) {
    memset(data, 0, size);
    if (bytes[zeroIdx] == 0) {
      break;
    }
  }
#endif
}
655
656 class GnuTLSTicketsKey
657 {
658 public:
659 GnuTLSTicketsKey()
660 {
661 if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
662 throw std::runtime_error("Error generating tickets key for TLS context");
663 }
664
665 safe_memory_lock(d_key.data, d_key.size);
666 }
667
668 GnuTLSTicketsKey(const std::string& keyFile)
669 {
670 /* to be sure we are loading the correct amount of data, which
671 may change between versions, let's generate a correct key first */
672 if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
673 throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
674 }
675
676 safe_memory_lock(d_key.data, d_key.size);
677
678 try {
679 ifstream file(keyFile);
680 file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
681
682 if (file.fail()) {
683 file.close();
684 throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
685 }
686
687 file.close();
688 }
689 catch (const std::exception& e) {
690 safe_memory_release(d_key.data, d_key.size);
691 gnutls_free(d_key.data);
692 d_key.data = nullptr;
693 throw;
694 }
695 }
696
697 ~GnuTLSTicketsKey()
698 {
699 if (d_key.data != nullptr && d_key.size > 0) {
700 safe_memory_release(d_key.data, d_key.size);
701 }
702 gnutls_free(d_key.data);
703 d_key.data = nullptr;
704 }
705 const gnutls_datum_t& getKey() const
706 {
707 return d_key;
708 }
709
710 private:
711 gnutls_datum_t d_key{nullptr, 0};
712 };
713
714 class GnuTLSConnection: public TLSConnection
715 {
716 public:
717
718 GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
719 {
720 unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
721 #ifdef GNUTLS_NO_SIGNAL
722 sslOptions |= GNUTLS_NO_SIGNAL;
723 #endif
724
725 d_socket = socket;
726
727 gnutls_session_t conn;
728 if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
729 throw std::runtime_error("Error creating TLS connection");
730 }
731
732 d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
733 conn = nullptr;
734
735 if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
736 throw std::runtime_error("Error setting certificate and key to TLS connection");
737 }
738
739 if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
740 throw std::runtime_error("Error setting ciphers to TLS connection");
741 }
742
743 if (enableTickets && d_ticketsKey) {
744 const gnutls_datum_t& key = d_ticketsKey->getKey();
745 if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
746 throw std::runtime_error("Error setting the tickets key to TLS connection");
747 }
748 }
749
750 gnutls_transport_set_int(d_conn.get(), d_socket);
751
752 /* timeouts are in milliseconds */
753 gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
754 gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
755 }
756
757 void doHandshake() override
758 {
759 int ret = 0;
760 do {
761 ret = gnutls_handshake(d_conn.get());
762 if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
763 throw std::runtime_error("Error accepting a new connection");
764 }
765 }
766 while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
767 }
768
769 IOState tryHandshake() override
770 {
771 int ret = 0;
772
773 do {
774 ret = gnutls_handshake(d_conn.get());
775 if (ret == GNUTLS_E_SUCCESS) {
776 return IOState::Done;
777 }
778 else if (ret == GNUTLS_E_AGAIN) {
779 return IOState::NeedRead;
780 }
781 else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
782 throw std::runtime_error("Error accepting a new connection");
783 }
784 } while (ret == GNUTLS_E_INTERRUPTED);
785
786 throw std::runtime_error("Error accepting a new connection");
787 }
788
789 IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
790 {
791 do {
792 ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
793 if (res == 0) {
794 throw std::runtime_error("Error writing to TLS connection");
795 }
796 else if (res > 0) {
797 pos += static_cast<size_t>(res);
798 }
799 else if (res < 0) {
800 if (gnutls_error_is_fatal(res)) {
801 throw std::runtime_error("Error writing to TLS connection");
802 }
803 else if (res == GNUTLS_E_AGAIN) {
804 return IOState::NeedWrite;
805 }
806 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
807 }
808 }
809 while (pos < toWrite);
810 return IOState::Done;
811 }
812
813 IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
814 {
815 do {
816 ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
817 if (res == 0) {
818 throw std::runtime_error("Error reading from TLS connection");
819 }
820 else if (res > 0) {
821 pos += static_cast<size_t>(res);
822 }
823 else if (res < 0) {
824 if (gnutls_error_is_fatal(res)) {
825 throw std::runtime_error("Error reading from TLS connection");
826 }
827 else if (res == GNUTLS_E_AGAIN) {
828 return IOState::NeedRead;
829 }
830 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
831 }
832 }
833 while (pos < toRead);
834 return IOState::Done;
835 }
836
837 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
838 {
839 size_t got = 0;
840 time_t start = 0;
841 unsigned int remainingTime = totalTimeout;
842 if (totalTimeout) {
843 start = time(nullptr);
844 }
845
846 do {
847 ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
848 if (res == 0) {
849 throw std::runtime_error("Error reading from TLS connection");
850 }
851 else if (res > 0) {
852 got += static_cast<size_t>(res);
853 }
854 else if (res < 0) {
855 if (gnutls_error_is_fatal(res)) {
856 throw std::runtime_error("Error reading from TLS connection:" + std::string(gnutls_strerror(res)));
857 }
858 else if (res == GNUTLS_E_AGAIN) {
859 int result = waitForData(d_socket, readTimeout);
860 if (result <= 0) {
861 throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
862 }
863 }
864 else {
865 vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
866 }
867 }
868
869 if (totalTimeout) {
870 time_t now = time(nullptr);
871 unsigned int elapsed = now - start;
872 if (now < start || elapsed >= remainingTime) {
873 throw runtime_error("Timeout while reading data");
874 }
875 start = now;
876 remainingTime -= elapsed;
877 }
878 }
879 while (got < bufferSize);
880
881 return got;
882 }
883
884 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
885 {
886 size_t got = 0;
887
888 do {
889 ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
890 if (res == 0) {
891 throw std::runtime_error("Error writing to TLS connection");
892 }
893 else if (res > 0) {
894 got += static_cast<size_t>(res);
895 }
896 else if (res < 0) {
897 if (gnutls_error_is_fatal(res)) {
898 throw std::runtime_error("Error writing to TLS connection: " + std::string(gnutls_strerror(res)));
899 }
900 else if (res == GNUTLS_E_AGAIN) {
901 int result = waitForRWData(d_socket, false, writeTimeout, 0);
902 if (result <= 0) {
903 throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
904 }
905 }
906 else {
907 vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
908 }
909 }
910 }
911 while (got < bufferSize);
912
913 return got;
914 }
915
916 void close() override
917 {
918 if (d_conn) {
919 gnutls_bye(d_conn.get(), GNUTLS_SHUT_WR);
920 }
921 }
922
923 private:
924 std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
925 std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
926 };
927
928 class GnuTLSIOCtx: public TLSCtx
929 {
930 public:
931 GnuTLSIOCtx(const TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_enableTickets)
932 {
933 int rc = 0;
934 d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay;
935
936 gnutls_certificate_credentials_t creds;
937 rc = gnutls_certificate_allocate_credentials(&creds);
938 if (rc != GNUTLS_E_SUCCESS) {
939 throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
940 }
941
942 d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
943 creds = nullptr;
944
945 for (const auto& pair : fe.d_certKeyPairs) {
946 rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM);
947 if (rc != GNUTLS_E_SUCCESS) {
948 throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
949 }
950 }
951
952 #if GNUTLS_VERSION_NUMBER >= 0x030600
953 rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
954 if (rc != GNUTLS_E_SUCCESS) {
955 throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
956 }
957 #endif
958
959 rc = gnutls_priority_init(&d_priorityCache, fe.d_ciphers.empty() ? "NORMAL" : fe.d_ciphers.c_str(), nullptr);
960 if (rc != GNUTLS_E_SUCCESS) {
961 warnlog("Error setting up TLS cipher preferences to %s (%s), skipping.", fe.d_ciphers.c_str(), gnutls_strerror(rc));
962 }
963
964 pthread_rwlock_init(&d_lock, nullptr);
965
966 try {
967 if (fe.d_ticketKeyFile.empty()) {
968 handleTicketsKeyRotation(time(nullptr));
969 }
970 else {
971 loadTicketsKeys(fe.d_ticketKeyFile);
972 }
973 }
974 catch(const std::runtime_error& e) {
975 pthread_rwlock_destroy(&d_lock);
976 throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
977 }
978 }
979
980 virtual ~GnuTLSIOCtx() override
981 {
982 pthread_rwlock_destroy(&d_lock);
983
984 d_creds.reset();
985
986 if (d_priorityCache) {
987 gnutls_priority_deinit(d_priorityCache);
988 }
989 }
990
991 std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
992 {
993 handleTicketsKeyRotation(now);
994
995 std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
996 {
997 ReadLock rl(&d_lock);
998 ticketsKey = d_ticketsKey;
999 }
1000
1001 return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets));
1002 }
1003
1004 void rotateTicketsKey(time_t now) override
1005 {
1006 if (!d_enableTickets) {
1007 return;
1008 }
1009
1010 auto newKey = std::make_shared<GnuTLSTicketsKey>();
1011
1012 {
1013 WriteLock wl(&d_lock);
1014 d_ticketsKey = newKey;
1015 }
1016
1017 if (d_ticketsKeyRotationDelay > 0) {
1018 d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
1019 }
1020 }
1021
1022 void loadTicketsKeys(const std::string& file) override
1023 {
1024 if (!d_enableTickets) {
1025 return;
1026 }
1027
1028 auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
1029 {
1030 WriteLock wl(&d_lock);
1031 d_ticketsKey = newKey;
1032 }
1033
1034 if (d_ticketsKeyRotationDelay > 0) {
1035 d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
1036 }
1037 }
1038
1039 size_t getTicketsKeysCount() override
1040 {
1041 ReadLock rl(&d_lock);
1042 return d_ticketsKey != nullptr ? 1 : 0;
1043 }
1044
1045 private:
1046 std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
1047 gnutls_priority_t d_priorityCache{nullptr};
1048 std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr};
1049 pthread_rwlock_t d_lock;
1050 bool d_enableTickets{true};
1051 };
1052
1053 #endif /* HAVE_GNUTLS */
1054
1055 #endif /* HAVE_DNS_OVER_TLS */
1056
1057 bool TLSFrontend::setupTLS()
1058 {
1059 #ifdef HAVE_DNS_OVER_TLS
1060 /* get the "best" available provider */
1061 if (!d_provider.empty()) {
1062 #ifdef HAVE_GNUTLS
1063 if (d_provider == "gnutls") {
1064 d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
1065 return true;
1066 }
1067 #endif /* HAVE_GNUTLS */
1068 #ifdef HAVE_LIBSSL
1069 if (d_provider == "openssl") {
1070 d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
1071 return true;
1072 }
1073 #endif /* HAVE_LIBSSL */
1074 }
1075 #ifdef HAVE_GNUTLS
1076 d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
1077 #else /* HAVE_GNUTLS */
1078 #ifdef HAVE_LIBSSL
1079 d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
1080 #endif /* HAVE_LIBSSL */
1081 #endif /* HAVE_GNUTLS */
1082
1083 #endif /* HAVE_DNS_OVER_TLS */
1084 return true;
1085 }