git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdistdist/tcpiohandler.cc
dnsdist: Add OCSP stapling (from files) for DoT and DoH
[thirdparty/pdns.git] / pdns / dnsdistdist / tcpiohandler.cc
1 #include <fstream>
2
3 #include "config.h"
4 #include "circular_buffer.hh"
5 #include "dolog.hh"
6 #include "iputils.hh"
7 #include "lock.hh"
8 #include "tcpiohandler.hh"
9
10 #ifdef HAVE_LIBSODIUM
11 #include <sodium.h>
12 #endif /* HAVE_LIBSODIUM */
13
14 #ifdef HAVE_DNS_OVER_TLS
15 #ifdef HAVE_LIBSSL
16 #include <openssl/conf.h>
17 #include <openssl/err.h>
18 #include <openssl/rand.h>
19 #include <openssl/ssl.h>
20
21 #include "libssl.hh"
22
/* Length of a ticket key name, following the ticket construction
   recommended in RFC 5077 section 4 */
#define TLS_TICKETS_KEY_NAME_SIZE 16

/* Ticket contents are encrypted with AES-256 (CBC), using a 32-byte key */
#define TLS_TICKETS_CIPHER_KEY_SIZE 32
#define TLS_TICKETS_CIPHER_ALGO (EVP_aes_256_cbc)

/* Ticket contents are authenticated with HMAC-SHA-256, using a 32-byte key */
#define TLS_TICKETS_MAC_KEY_SIZE 32
#define TLS_TICKETS_MAC_ALGO (EVP_sha256)
33
/* SSL_CTX ex_data slot used to find the owning OpenSSLTLSIOCtx from the
   tickets key callback; -1 means no slot has been allocated yet */
static int s_ticketsKeyIndex = -1;
35
36 class OpenSSLTLSTicketKey
37 {
38 public:
39 OpenSSLTLSTicketKey()
40 {
41 if (RAND_bytes(d_name, sizeof(d_name)) != 1) {
42 throw std::runtime_error("Error while generating the name of the OpenSSL TLS ticket key");
43 }
44
45 if (RAND_bytes(d_cipherKey, sizeof(d_cipherKey)) != 1) {
46 throw std::runtime_error("Error while generating the cipher key of the OpenSSL TLS ticket key");
47 }
48
49 if (RAND_bytes(d_hmacKey, sizeof(d_hmacKey)) != 1) {
50 throw std::runtime_error("Error while generating the HMAC key of the OpenSSL TLS ticket key");
51 }
52 #ifdef HAVE_LIBSODIUM
53 sodium_mlock(d_name, sizeof(d_name));
54 sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
55 sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
56 #endif /* HAVE_LIBSODIUM */
57 }
58
59 OpenSSLTLSTicketKey(ifstream& file)
60 {
61 file.read(reinterpret_cast<char*>(d_name), sizeof(d_name));
62 file.read(reinterpret_cast<char*>(d_cipherKey), sizeof(d_cipherKey));
63 file.read(reinterpret_cast<char*>(d_hmacKey), sizeof(d_hmacKey));
64
65 if (file.fail()) {
66 throw std::runtime_error("Unable to load a ticket key from the OpenSSL tickets key file");
67 }
68 #ifdef HAVE_LIBSODIUM
69 sodium_mlock(d_name, sizeof(d_name));
70 sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
71 sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
72 #endif /* HAVE_LIBSODIUM */
73 }
74
75 ~OpenSSLTLSTicketKey()
76 {
77 #ifdef HAVE_LIBSODIUM
78 sodium_munlock(d_name, sizeof(d_name));
79 sodium_munlock(d_cipherKey, sizeof(d_cipherKey));
80 sodium_munlock(d_hmacKey, sizeof(d_hmacKey));
81 #else
82 OPENSSL_cleanse(d_name, sizeof(d_name));
83 OPENSSL_cleanse(d_cipherKey, sizeof(d_cipherKey));
84 OPENSSL_cleanse(d_hmacKey, sizeof(d_hmacKey));
85 #endif /* HAVE_LIBSODIUM */
86 }
87
88 bool nameMatches(const unsigned char name[TLS_TICKETS_KEY_NAME_SIZE]) const
89 {
90 return (memcmp(d_name, name, sizeof(d_name)) == 0);
91 }
92
93 int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const
94 {
95 memcpy(keyName, d_name, sizeof(d_name));
96
97 if (RAND_bytes(iv, EVP_MAX_IV_LENGTH) != 1) {
98 return -1;
99 }
100
101 if (EVP_EncryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) {
102 return -1;
103 }
104
105 if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
106 return -1;
107 }
108
109 return 1;
110 }
111
112 bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const
113 {
114 if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
115 return false;
116 }
117
118 if (EVP_DecryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) {
119 return false;
120 }
121
122 return true;
123 }
124
125 private:
126 unsigned char d_name[TLS_TICKETS_KEY_NAME_SIZE];
127 unsigned char d_cipherKey[TLS_TICKETS_CIPHER_KEY_SIZE];
128 unsigned char d_hmacKey[TLS_TICKETS_MAC_KEY_SIZE];
129 };
130
131 class OpenSSLTLSTicketKeysRing
132 {
133 public:
134 OpenSSLTLSTicketKeysRing(size_t capacity)
135 {
136 pthread_rwlock_init(&d_lock, nullptr);
137 d_ticketKeys.set_capacity(capacity);
138 }
139
140 ~OpenSSLTLSTicketKeysRing()
141 {
142 pthread_rwlock_destroy(&d_lock);
143 }
144
145 void addKey(std::shared_ptr<OpenSSLTLSTicketKey> newKey)
146 {
147 WriteLock wl(&d_lock);
148 d_ticketKeys.push_back(newKey);
149 }
150
151 std::shared_ptr<OpenSSLTLSTicketKey> getEncryptionKey()
152 {
153 ReadLock rl(&d_lock);
154 return d_ticketKeys.front();
155 }
156
157 std::shared_ptr<OpenSSLTLSTicketKey> getDecryptionKey(unsigned char name[TLS_TICKETS_KEY_NAME_SIZE], bool& activeKey)
158 {
159 ReadLock rl(&d_lock);
160 for (auto& key : d_ticketKeys) {
161 if (key->nameMatches(name)) {
162 activeKey = (key == d_ticketKeys.front());
163 return key;
164 }
165 }
166 return nullptr;
167 }
168
169 size_t getKeysCount()
170 {
171 ReadLock rl(&d_lock);
172 return d_ticketKeys.size();
173 }
174
175 private:
176 boost::circular_buffer<std::shared_ptr<OpenSSLTLSTicketKey> > d_ticketKeys;
177 pthread_rwlock_t d_lock;
178 };
179
180 class OpenSSLTLSConnection: public TLSConnection
181 {
182 public:
183 OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free)), d_timeout(timeout)
184 {
185 d_socket = socket;
186
187 if (!d_conn) {
188 vinfolog("Error creating TLS object");
189 if (g_verbose) {
190 ERR_print_errors_fp(stderr);
191 }
192 throw std::runtime_error("Error creating TLS object");
193 }
194
195 if (!SSL_set_fd(d_conn.get(), d_socket)) {
196 throw std::runtime_error("Error assigning socket");
197 }
198 }
199
200 IOState convertIORequestToIOState(int res) const
201 {
202 int error = SSL_get_error(d_conn.get(), res);
203 if (error == SSL_ERROR_WANT_READ) {
204 return IOState::NeedRead;
205 }
206 else if (error == SSL_ERROR_WANT_WRITE) {
207 return IOState::NeedWrite;
208 }
209 else if (error == SSL_ERROR_SYSCALL) {
210 throw std::runtime_error("Error while processing TLS connection: " + std::string(strerror(errno)));
211 }
212 else {
213 throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error));
214 }
215 }
216
217 void handleIORequest(int res, unsigned int timeout)
218 {
219 auto state = convertIORequestToIOState(res);
220 if (state == IOState::NeedRead) {
221 res = waitForData(d_socket, timeout);
222 if (res == 0) {
223 throw std::runtime_error("Timeout while reading from TLS connection");
224 }
225 else if (res < 0) {
226 throw std::runtime_error("Error waiting to read from TLS connection");
227 }
228 }
229 else if (state == IOState::NeedWrite) {
230 res = waitForRWData(d_socket, false, timeout, 0);
231 if (res == 0) {
232 throw std::runtime_error("Timeout while writing to TLS connection");
233 }
234 else if (res < 0) {
235 throw std::runtime_error("Error waiting to write to TLS connection");
236 }
237 }
238 }
239
240 IOState tryHandshake() override
241 {
242 int res = SSL_accept(d_conn.get());
243 if (res == 1) {
244 return IOState::Done;
245 }
246 else if (res < 0) {
247 return convertIORequestToIOState(res);
248 }
249
250 throw std::runtime_error("Error accepting TLS connection");
251 }
252
253 void doHandshake() override
254 {
255 int res = 0;
256 do {
257 res = SSL_accept(d_conn.get());
258 if (res < 0) {
259 handleIORequest(res, d_timeout);
260 }
261 }
262 while (res < 0);
263
264 if (res != 1) {
265 throw std::runtime_error("Error accepting TLS connection");
266 }
267 }
268
269 IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
270 {
271 do {
272 int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
273 if (res <= 0) {
274 return convertIORequestToIOState(res);
275 }
276 else {
277 pos += static_cast<size_t>(res);
278 }
279 }
280 while (pos < toWrite);
281 return IOState::Done;
282 }
283
284 IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
285 {
286 do {
287 int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
288 if (res <= 0) {
289 return convertIORequestToIOState(res);
290 }
291 else {
292 pos += static_cast<size_t>(res);
293 }
294 }
295 while (pos < toRead);
296 return IOState::Done;
297 }
298
299 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
300 {
301 size_t got = 0;
302 time_t start = 0;
303 unsigned int remainingTime = totalTimeout;
304 if (totalTimeout) {
305 start = time(nullptr);
306 }
307
308 do {
309 int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
310 if (res <= 0) {
311 handleIORequest(res, readTimeout);
312 }
313 else {
314 got += static_cast<size_t>(res);
315 }
316
317 if (totalTimeout) {
318 time_t now = time(nullptr);
319 unsigned int elapsed = now - start;
320 if (now < start || elapsed >= remainingTime) {
321 throw runtime_error("Timeout while reading data");
322 }
323 start = now;
324 remainingTime -= elapsed;
325 }
326 }
327 while (got < bufferSize);
328
329 return got;
330 }
331
332 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
333 {
334 size_t got = 0;
335 do {
336 int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
337 if (res <= 0) {
338 handleIORequest(res, writeTimeout);
339 }
340 else {
341 got += static_cast<size_t>(res);
342 }
343 }
344 while (got < bufferSize);
345
346 return got;
347 }
348
349 void close() override
350 {
351 if (d_conn) {
352 SSL_shutdown(d_conn.get());
353 }
354 }
355
356 std::string getServerNameIndication() override
357 {
358 if (d_conn) {
359 const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
360 if (value) {
361 return std::string(value);
362 }
363 }
364 return std::string();
365 }
366
367 private:
368 std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
369 unsigned int d_timeout;
370 };
371
372 class OpenSSLTLSIOCtx: public TLSCtx
373 {
374 public:
375 OpenSSLTLSIOCtx(const TLSFrontend& fe): d_ticketKeys(fe.d_numberOfTicketsKeys), d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
376 {
377 d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay;
378
379 int sslOptions =
380 SSL_OP_NO_SSLv2 |
381 SSL_OP_NO_SSLv3 |
382 SSL_OP_NO_COMPRESSION |
383 SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
384 SSL_OP_SINGLE_DH_USE |
385 SSL_OP_SINGLE_ECDH_USE |
386 SSL_OP_CIPHER_SERVER_PREFERENCE;
387
388 if (!fe.d_enableTickets) {
389 sslOptions |= SSL_OP_NO_TICKET;
390 }
391
392 if (s_users.fetch_add(1) == 0) {
393 registerOpenSSLUser();
394
395 s_ticketsKeyIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
396
397 if (s_ticketsKeyIndex == -1) {
398 throw std::runtime_error("Error getting an index for tickets key");
399 }
400 }
401
402 d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free);
403 if (!d_tlsCtx) {
404 ERR_print_errors_fp(stderr);
405 throw std::runtime_error("Error creating TLS context on " + fe.d_addr.toStringWithPort());
406 }
407
408 /* use our own ticket keys handler so we can rotate them */
409 SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
410 SSL_CTX_set_ex_data(d_tlsCtx.get(), s_ticketsKeyIndex, this);
411 SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
412 #if defined(SSL_CTX_set_ecdh_auto)
413 SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
414 #endif
415 if (fe.d_maxStoredSessions == 0) {
416 /* disable stored sessions entirely */
417 SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_OFF);
418 }
419 else {
420 /* use the internal built-in cache to store sessions */
421 SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_SERVER);
422 SSL_CTX_sess_set_cache_size(d_tlsCtx.get(), fe.d_maxStoredSessions);
423 }
424
425 std::vector<int> keyTypes;
426 for (const auto& pair : fe.d_certKeyPairs) {
427 if (SSL_CTX_use_certificate_chain_file(d_tlsCtx.get(), pair.first.c_str()) != 1) {
428 ERR_print_errors_fp(stderr);
429 throw std::runtime_error("Error loading certificate from " + pair.first + " for the TLS context on " + fe.d_addr.toStringWithPort());
430 }
431 if (SSL_CTX_use_PrivateKey_file(d_tlsCtx.get(), pair.second.c_str(), SSL_FILETYPE_PEM) != 1) {
432 ERR_print_errors_fp(stderr);
433 throw std::runtime_error("Error loading key from " + pair.second + " for the TLS context on " + fe.d_addr.toStringWithPort());
434 }
435 if (SSL_CTX_check_private_key(d_tlsCtx.get()) != 1) {
436 ERR_print_errors_fp(stderr);
437 throw std::runtime_error("Key from '" + pair.second + "' does not match the certificate from '" + pair.first + "' for the TLS context on " + fe.d_addr.toStringWithPort());
438 }
439
440 /* store the type of the new key, we might need it later to select the right OCSP stapling response */
441 keyTypes.push_back(libssl_get_last_key_type(d_tlsCtx));
442 }
443
444 if (!fe.d_ocspFiles.empty()) {
445 try {
446 d_ocspResponses = libssl_load_ocsp_responses(fe.d_ocspFiles, keyTypes);
447 }
448 catch(const std::exception& e) {
449 throw std::runtime_error("Error loading responses for the TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
450 }
451
452 SSL_CTX_set_tlsext_status_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
453 SSL_CTX_set_tlsext_status_arg(d_tlsCtx.get(), &d_ocspResponses);
454 }
455
456 if (!fe.d_ciphers.empty()) {
457 if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), fe.d_ciphers.c_str()) != 1) {
458 ERR_print_errors_fp(stderr);
459 throw std::runtime_error("Error setting the cipher list to '" + fe.d_ciphers + "' for the TLS context on " + fe.d_addr.toStringWithPort());
460 }
461 }
462
463 #ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
464 if (!fe.d_ciphers13.empty()) {
465 if (SSL_CTX_set_ciphersuites(d_tlsCtx.get(), fe.d_ciphers13.c_str()) != 1) {
466 ERR_print_errors_fp(stderr);
467 throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + fe.d_ciphers13 + "' for the TLS context on " + fe.d_addr.toStringWithPort());
468 }
469 }
470 #endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
471
472 try {
473 if (fe.d_ticketKeyFile.empty()) {
474 handleTicketsKeyRotation(time(nullptr));
475 }
476 else {
477 loadTicketsKeys(fe.d_ticketKeyFile);
478 }
479 }
480 catch (const std::exception& e) {
481 throw;
482 }
483 }
484
485 virtual ~OpenSSLTLSIOCtx() override
486 {
487 d_tlsCtx.reset();
488
489 if (s_users.fetch_sub(1) == 1) {
490 unregisterOpenSSLUser();
491 }
492 }
493
494 static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
495 {
496 SSL_CTX* sslCtx = SSL_get_SSL_CTX(s);
497 if (sslCtx == nullptr) {
498 return -1;
499 }
500
501 OpenSSLTLSIOCtx* ctx = reinterpret_cast<OpenSSLTLSIOCtx*>(SSL_CTX_get_ex_data(sslCtx, s_ticketsKeyIndex));
502 if (ctx == nullptr) {
503 return -1;
504 }
505
506 if (enc) {
507 const auto key = ctx->d_ticketKeys.getEncryptionKey();
508 if (key == nullptr) {
509 return -1;
510 }
511
512 return key->encrypt(keyName, iv, ectx, hctx);
513 }
514
515 bool activeEncryptionKey = false;
516
517 const auto key = ctx->d_ticketKeys.getDecryptionKey(keyName, activeEncryptionKey);
518 if (key == nullptr) {
519 /* we don't know this key, just create a new ticket */
520 return 0;
521 }
522
523 if (key->decrypt(iv, ectx, hctx) == false) {
524 return -1;
525 }
526
527 if (!activeEncryptionKey) {
528 /* this key is not active, please encrypt the ticket content with the currently active one */
529 return 2;
530 }
531
532 return 1;
533 }
534
535 static int ocspStaplingCb(SSL* ssl, void* arg)
536 {
537 if (ssl == nullptr || arg == nullptr) {
538 return SSL_TLSEXT_ERR_NOACK;
539 }
540 const auto ocspMap = reinterpret_cast<std::map<int, std::string>*>(arg);
541 return libssl_ocsp_stapling_callback(ssl, *ocspMap);
542 }
543
544 std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
545 {
546 handleTicketsKeyRotation(now);
547
548 return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_tlsCtx.get()));
549 }
550
551 void rotateTicketsKey(time_t now) override
552 {
553 auto newKey = std::make_shared<OpenSSLTLSTicketKey>();
554 d_ticketKeys.addKey(newKey);
555
556 if (d_ticketsKeyRotationDelay > 0) {
557 d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
558 }
559 }
560
561 void loadTicketsKeys(const std::string& keyFile) override
562 {
563 bool keyLoaded = false;
564 ifstream file(keyFile);
565 try {
566 do {
567 auto newKey = std::make_shared<OpenSSLTLSTicketKey>(file);
568 d_ticketKeys.addKey(newKey);
569 keyLoaded = true;
570 }
571 while (!file.fail());
572 }
573 catch (const std::exception& e) {
574 /* if we haven't been able to load at least one key, fail */
575 if (!keyLoaded) {
576 throw;
577 }
578 }
579
580 if (d_ticketsKeyRotationDelay > 0) {
581 d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
582 }
583
584 file.close();
585 }
586
587 size_t getTicketsKeysCount() override
588 {
589 return d_ticketKeys.getKeysCount();
590 }
591
592 private:
593 OpenSSLTLSTicketKeysRing d_ticketKeys;
594 std::map<int, std::string> d_ocspResponses;
595 std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx;
596 static std::atomic<uint64_t> s_users;
597 };
598
599 std::atomic<uint64_t> OpenSSLTLSIOCtx::s_users(0);
600
601 #endif /* HAVE_LIBSSL */
602
603 #ifdef HAVE_GNUTLS
604 #include <gnutls/gnutls.h>
605 #include <gnutls/x509.h>
606
/* Best-effort attempt to keep the [data, data + size) range from being
   swapped out. Only implemented with libsodium, a no-op otherwise. */
void safe_memory_lock(void* data, size_t size)
{
#ifndef HAVE_LIBSODIUM
  (void) data;
  (void) size;
#else
  sodium_mlock(data, size);
#endif
}
613
/* Wipes the [data, data + size) range, unlocking it as well when libsodium
   is in use. The first available implementation is picked: libsodium,
   explicit_bzero(), explicit_memset(), gnutls_memset(), then a portable
   fallback. */
void safe_memory_release(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  /* shamelessly taken from Dovecot's src/lib/safe-memset.c: the check
     through a volatile pointer is meant to keep the compiler from
     optimizing the memset() away */
  if (size == 0) {
    return;
  }

  volatile unsigned int zeroIndex = 0;
  volatile unsigned char* bytes = static_cast<volatile unsigned char*>(data);
  for (;;) {
    memset(data, 0, size);
    if (bytes[zeroIndex] == 0) {
      break;
    }
  }
#endif
}
637
638 class GnuTLSTicketsKey
639 {
640 public:
641 GnuTLSTicketsKey()
642 {
643 if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
644 throw std::runtime_error("Error generating tickets key for TLS context");
645 }
646
647 safe_memory_lock(d_key.data, d_key.size);
648 }
649
650 GnuTLSTicketsKey(const std::string& keyFile)
651 {
652 /* to be sure we are loading the correct amount of data, which
653 may change between versions, let's generate a correct key first */
654 if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
655 throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
656 }
657
658 safe_memory_lock(d_key.data, d_key.size);
659
660 try {
661 ifstream file(keyFile);
662 file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
663
664 if (file.fail()) {
665 file.close();
666 throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
667 }
668
669 file.close();
670 }
671 catch (const std::exception& e) {
672 safe_memory_release(d_key.data, d_key.size);
673 gnutls_free(d_key.data);
674 d_key.data = nullptr;
675 throw;
676 }
677 }
678
679 ~GnuTLSTicketsKey()
680 {
681 if (d_key.data != nullptr && d_key.size > 0) {
682 safe_memory_release(d_key.data, d_key.size);
683 }
684 gnutls_free(d_key.data);
685 d_key.data = nullptr;
686 }
687 const gnutls_datum_t& getKey() const
688 {
689 return d_key;
690 }
691
692 private:
693 gnutls_datum_t d_key{nullptr, 0};
694 };
695
696 class GnuTLSConnection: public TLSConnection
697 {
698 public:
699
700 GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
701 {
702 unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
703 #ifdef GNUTLS_NO_SIGNAL
704 sslOptions |= GNUTLS_NO_SIGNAL;
705 #endif
706
707 d_socket = socket;
708
709 gnutls_session_t conn;
710 if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
711 throw std::runtime_error("Error creating TLS connection");
712 }
713
714 d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
715 conn = nullptr;
716
717 if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
718 throw std::runtime_error("Error setting certificate and key to TLS connection");
719 }
720
721 if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
722 throw std::runtime_error("Error setting ciphers to TLS connection");
723 }
724
725 if (enableTickets && d_ticketsKey) {
726 const gnutls_datum_t& key = d_ticketsKey->getKey();
727 if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
728 throw std::runtime_error("Error setting the tickets key to TLS connection");
729 }
730 }
731
732 gnutls_transport_set_int(d_conn.get(), d_socket);
733
734 /* timeouts are in milliseconds */
735 gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
736 gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
737 }
738
739 void doHandshake() override
740 {
741 int ret = 0;
742 do {
743 ret = gnutls_handshake(d_conn.get());
744 if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
745 throw std::runtime_error("Error accepting a new connection");
746 }
747 }
748 while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
749 }
750
751 IOState tryHandshake() override
752 {
753 int ret = 0;
754
755 do {
756 ret = gnutls_handshake(d_conn.get());
757 if (ret == GNUTLS_E_SUCCESS) {
758 return IOState::Done;
759 }
760 else if (ret == GNUTLS_E_AGAIN) {
761 return IOState::NeedRead;
762 }
763 else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
764 throw std::runtime_error("Error accepting a new connection");
765 }
766 } while (ret == GNUTLS_E_INTERRUPTED);
767
768 throw std::runtime_error("Error accepting a new connection");
769 }
770
771 IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
772 {
773 do {
774 ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
775 if (res == 0) {
776 throw std::runtime_error("Error writing to TLS connection");
777 }
778 else if (res > 0) {
779 pos += static_cast<size_t>(res);
780 }
781 else if (res < 0) {
782 if (gnutls_error_is_fatal(res)) {
783 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
784 }
785 else if (res == GNUTLS_E_AGAIN) {
786 return IOState::NeedWrite;
787 }
788 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
789 }
790 }
791 while (pos < toWrite);
792 return IOState::Done;
793 }
794
795 IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
796 {
797 do {
798 ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
799 if (res == 0) {
800 throw std::runtime_error("Error reading from TLS connection");
801 }
802 else if (res > 0) {
803 pos += static_cast<size_t>(res);
804 }
805 else if (res < 0) {
806 if (gnutls_error_is_fatal(res)) {
807 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
808 }
809 else if (res == GNUTLS_E_AGAIN) {
810 return IOState::NeedRead;
811 }
812 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
813 }
814 }
815 while (pos < toRead);
816 return IOState::Done;
817 }
818
819 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
820 {
821 size_t got = 0;
822 time_t start = 0;
823 unsigned int remainingTime = totalTimeout;
824 if (totalTimeout) {
825 start = time(nullptr);
826 }
827
828 do {
829 ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
830 if (res == 0) {
831 throw std::runtime_error("Error reading from TLS connection");
832 }
833 else if (res > 0) {
834 got += static_cast<size_t>(res);
835 }
836 else if (res < 0) {
837 if (gnutls_error_is_fatal(res)) {
838 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
839 }
840 else if (res == GNUTLS_E_AGAIN) {
841 int result = waitForData(d_socket, readTimeout);
842 if (result <= 0) {
843 throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result));
844 }
845 }
846 else {
847 vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
848 }
849 }
850
851 if (totalTimeout) {
852 time_t now = time(nullptr);
853 unsigned int elapsed = now - start;
854 if (now < start || elapsed >= remainingTime) {
855 throw runtime_error("Timeout while reading data");
856 }
857 start = now;
858 remainingTime -= elapsed;
859 }
860 }
861 while (got < bufferSize);
862
863 return got;
864 }
865
866 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
867 {
868 size_t got = 0;
869
870 do {
871 ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
872 if (res == 0) {
873 throw std::runtime_error("Error writing to TLS connection");
874 }
875 else if (res > 0) {
876 got += static_cast<size_t>(res);
877 }
878 else if (res < 0) {
879 if (gnutls_error_is_fatal(res)) {
880 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
881 }
882 else if (res == GNUTLS_E_AGAIN) {
883 int result = waitForRWData(d_socket, false, writeTimeout, 0);
884 if (result <= 0) {
885 throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
886 }
887 }
888 else {
889 vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
890 }
891 }
892 }
893 while (got < bufferSize);
894
895 return got;
896 }
897
898 std::string getServerNameIndication() override
899 {
900 if (d_conn) {
901 unsigned int type;
902 size_t name_len = 256;
903 std::string sni;
904 sni.resize(name_len);
905
906 int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0);
907 if (res == GNUTLS_E_SUCCESS) {
908 sni.resize(name_len);
909 return sni;
910 }
911 }
912 return std::string();
913 }
914
915 void close() override
916 {
917 if (d_conn) {
918 gnutls_bye(d_conn.get(), GNUTLS_SHUT_WR);
919 }
920 }
921
922 private:
923 std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
924 std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
925 };
926
927 class GnuTLSIOCtx: public TLSCtx
928 {
929 public:
930 GnuTLSIOCtx(const TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_enableTickets)
931 {
932 int rc = 0;
933 d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay;
934
935 gnutls_certificate_credentials_t creds;
936 rc = gnutls_certificate_allocate_credentials(&creds);
937 if (rc != GNUTLS_E_SUCCESS) {
938 throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
939 }
940
941 d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
942 creds = nullptr;
943
944 for (const auto& pair : fe.d_certKeyPairs) {
945 rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM);
946 if (rc != GNUTLS_E_SUCCESS) {
947 throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
948 }
949 }
950
951 size_t count = 0;
952 for (const auto& file : fe.d_ocspFiles) {
953 rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count);
954 if (rc != GNUTLS_E_SUCCESS) {
955 throw std::runtime_error("Error loading OCSP response from file '" + file + "' for certificate ('" + fe.d_certKeyPairs.at(count).first + "') and key ('" + fe.d_certKeyPairs.at(count).second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
956 }
957 ++count;
958 }
959
960 #if GNUTLS_VERSION_NUMBER >= 0x030600
961 rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
962 if (rc != GNUTLS_E_SUCCESS) {
963 throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
964 }
965 #endif
966
967 rc = gnutls_priority_init(&d_priorityCache, fe.d_ciphers.empty() ? "NORMAL" : fe.d_ciphers.c_str(), nullptr);
968 if (rc != GNUTLS_E_SUCCESS) {
969 throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort());
970 }
971
972 pthread_rwlock_init(&d_lock, nullptr);
973
974 try {
975 if (fe.d_ticketKeyFile.empty()) {
976 handleTicketsKeyRotation(time(nullptr));
977 }
978 else {
979 loadTicketsKeys(fe.d_ticketKeyFile);
980 }
981 }
982 catch(const std::runtime_error& e) {
983 pthread_rwlock_destroy(&d_lock);
984 throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
985 }
986 }
987
988 virtual ~GnuTLSIOCtx() override
989 {
990 pthread_rwlock_destroy(&d_lock);
991
992 d_creds.reset();
993
994 if (d_priorityCache) {
995 gnutls_priority_deinit(d_priorityCache);
996 }
997 }
998
999 std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
1000 {
1001 handleTicketsKeyRotation(now);
1002
1003 std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
1004 {
1005 ReadLock rl(&d_lock);
1006 ticketsKey = d_ticketsKey;
1007 }
1008
1009 return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets));
1010 }
1011
1012 void rotateTicketsKey(time_t now) override
1013 {
1014 if (!d_enableTickets) {
1015 return;
1016 }
1017
1018 auto newKey = std::make_shared<GnuTLSTicketsKey>();
1019
1020 {
1021 WriteLock wl(&d_lock);
1022 d_ticketsKey = newKey;
1023 }
1024
1025 if (d_ticketsKeyRotationDelay > 0) {
1026 d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
1027 }
1028 }
1029
1030 void loadTicketsKeys(const std::string& file) override
1031 {
1032 if (!d_enableTickets) {
1033 return;
1034 }
1035
1036 auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
1037 {
1038 WriteLock wl(&d_lock);
1039 d_ticketsKey = newKey;
1040 }
1041
1042 if (d_ticketsKeyRotationDelay > 0) {
1043 d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
1044 }
1045 }
1046
1047 size_t getTicketsKeysCount() override
1048 {
1049 ReadLock rl(&d_lock);
1050 return d_ticketsKey != nullptr ? 1 : 0;
1051 }
1052
1053 private:
1054 std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
1055 gnutls_priority_t d_priorityCache{nullptr};
1056 std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr};
1057 pthread_rwlock_t d_lock;
1058 bool d_enableTickets{true};
1059 };
1060
1061 #endif /* HAVE_GNUTLS */
1062
1063 #endif /* HAVE_DNS_OVER_TLS */
1064
1065 bool TLSFrontend::setupTLS()
1066 {
1067 #ifdef HAVE_DNS_OVER_TLS
1068 /* get the "best" available provider */
1069 if (!d_provider.empty()) {
1070 #ifdef HAVE_GNUTLS
1071 if (d_provider == "gnutls") {
1072 d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
1073 return true;
1074 }
1075 #endif /* HAVE_GNUTLS */
1076 #ifdef HAVE_LIBSSL
1077 if (d_provider == "openssl") {
1078 d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
1079 return true;
1080 }
1081 #endif /* HAVE_LIBSSL */
1082 }
1083 #ifdef HAVE_GNUTLS
1084 d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
1085 #else /* HAVE_GNUTLS */
1086 #ifdef HAVE_LIBSSL
1087 d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
1088 #endif /* HAVE_LIBSSL */
1089 #endif /* HAVE_GNUTLS */
1090
1091 #endif /* HAVE_DNS_OVER_TLS */
1092 return true;
1093 }