]> git.ipfire.org Git - thirdparty/pdns.git/blame - pdns/dnsdistdist/tcpiohandler.cc
dnsdist: Add 'ciphersTLS13' for DoT
[thirdparty/pdns.git] / pdns / dnsdistdist / tcpiohandler.cc
CommitLineData
a227f47d
RG
1#include <fstream>
2
3#include "config.h"
4#include "dolog.hh"
5#include "iputils.hh"
6#include "lock.hh"
7#include "tcpiohandler.hh"
8
9#ifdef HAVE_LIBSODIUM
10#include <sodium.h>
11#endif /* HAVE_LIBSODIUM */
12
13#ifdef HAVE_DNS_OVER_TLS
14#ifdef HAVE_LIBSSL
15#include <openssl/conf.h>
16#include <openssl/err.h>
17#include <openssl/rand.h>
18#include <openssl/ssl.h>
19
20#include <boost/circular_buffer.hpp>
21
ede152ec 22#include "libssl.hh"
a227f47d
RG
23
/* Session tickets are built following RFC 5077 section 4 ("Recommended
   Ticket Construction"): a key name identifying the key, an AES-256-CBC
   encryption key and an HMAC-SHA256 authentication key. */

/* length of the key name stored in clear in every ticket */
#define TLS_TICKETS_KEY_NAME_SIZE (16)

/* encryption: AES in CBC mode with a 256-bit key */
#define TLS_TICKETS_CIPHER_KEY_SIZE (32)
#define TLS_TICKETS_CIPHER_ALGO (EVP_aes_256_cbc)

/* authentication: HMAC over SHA-256 with a 256-bit key */
#define TLS_TICKETS_MAC_KEY_SIZE (32)
#define TLS_TICKETS_MAC_ALGO (EVP_sha256)

/* ex_data slot used to attach the owning OpenSSLTLSIOCtx to its SSL_CTX,
   stays at -1 until SSL_CTX_get_ex_new_index() has been called */
static int s_ticketsKeyIndex{-1};
36
37class OpenSSLTLSTicketKey
38{
39public:
40 OpenSSLTLSTicketKey()
41 {
42 if (RAND_bytes(d_name, sizeof(d_name)) != 1) {
43 throw std::runtime_error("Error while generating the name of the OpenSSL TLS ticket key");
44 }
45
46 if (RAND_bytes(d_cipherKey, sizeof(d_cipherKey)) != 1) {
47 throw std::runtime_error("Error while generating the cipher key of the OpenSSL TLS ticket key");
48 }
49
50 if (RAND_bytes(d_hmacKey, sizeof(d_hmacKey)) != 1) {
51 throw std::runtime_error("Error while generating the HMAC key of the OpenSSL TLS ticket key");
52 }
53#ifdef HAVE_LIBSODIUM
54 sodium_mlock(d_name, sizeof(d_name));
55 sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
56 sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
57#endif /* HAVE_LIBSODIUM */
58 }
59
60 OpenSSLTLSTicketKey(ifstream& file)
61 {
62 file.read(reinterpret_cast<char*>(d_name), sizeof(d_name));
63 file.read(reinterpret_cast<char*>(d_cipherKey), sizeof(d_cipherKey));
64 file.read(reinterpret_cast<char*>(d_hmacKey), sizeof(d_hmacKey));
65
66 if (file.fail()) {
67 throw std::runtime_error("Unable to load a ticket key from the OpenSSL tickets key file");
68 }
69#ifdef HAVE_LIBSODIUM
70 sodium_mlock(d_name, sizeof(d_name));
71 sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
72 sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
73#endif /* HAVE_LIBSODIUM */
74 }
75
76 ~OpenSSLTLSTicketKey()
77 {
78#ifdef HAVE_LIBSODIUM
79 sodium_munlock(d_name, sizeof(d_name));
80 sodium_munlock(d_cipherKey, sizeof(d_cipherKey));
81 sodium_munlock(d_hmacKey, sizeof(d_hmacKey));
82#else
83 OPENSSL_cleanse(d_name, sizeof(d_name));
84 OPENSSL_cleanse(d_cipherKey, sizeof(d_cipherKey));
85 OPENSSL_cleanse(d_hmacKey, sizeof(d_hmacKey));
86#endif /* HAVE_LIBSODIUM */
87 }
88
89 bool nameMatches(const unsigned char name[TLS_TICKETS_KEY_NAME_SIZE]) const
90 {
91 return (memcmp(d_name, name, sizeof(d_name)) == 0);
92 }
93
94 int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const
95 {
96 memcpy(keyName, d_name, sizeof(d_name));
97
98 if (RAND_bytes(iv, EVP_MAX_IV_LENGTH) != 1) {
99 return -1;
100 }
101
102 if (EVP_EncryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) {
103 return -1;
104 }
105
106 if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
107 return -1;
108 }
109
110 return 1;
111 }
112
113 bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const
114 {
115 if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
116 return false;
117 }
118
119 if (EVP_DecryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) {
120 return false;
121 }
122
123 return true;
124 }
125
126private:
127 unsigned char d_name[TLS_TICKETS_KEY_NAME_SIZE];
128 unsigned char d_cipherKey[TLS_TICKETS_CIPHER_KEY_SIZE];
129 unsigned char d_hmacKey[TLS_TICKETS_MAC_KEY_SIZE];
130};
131
132class OpenSSLTLSTicketKeysRing
133{
134public:
135 OpenSSLTLSTicketKeysRing(size_t capacity)
136 {
137 pthread_rwlock_init(&d_lock, nullptr);
138 d_ticketKeys.set_capacity(capacity);
139 }
140
141 ~OpenSSLTLSTicketKeysRing()
142 {
143 pthread_rwlock_destroy(&d_lock);
144 }
145
146 void addKey(std::shared_ptr<OpenSSLTLSTicketKey> newKey)
147 {
148 WriteLock wl(&d_lock);
149 d_ticketKeys.push_back(newKey);
150 }
151
152 std::shared_ptr<OpenSSLTLSTicketKey> getEncryptionKey()
153 {
154 ReadLock rl(&d_lock);
155 return d_ticketKeys.front();
156 }
157
158 std::shared_ptr<OpenSSLTLSTicketKey> getDecryptionKey(unsigned char name[TLS_TICKETS_KEY_NAME_SIZE], bool& activeKey)
159 {
160 ReadLock rl(&d_lock);
161 for (auto& key : d_ticketKeys) {
162 if (key->nameMatches(name)) {
163 activeKey = (key == d_ticketKeys.front());
164 return key;
165 }
166 }
167 return nullptr;
168 }
169
170 size_t getKeysCount()
171 {
172 ReadLock rl(&d_lock);
173 return d_ticketKeys.size();
174 }
175
176private:
177 boost::circular_buffer<std::shared_ptr<OpenSSLTLSTicketKey> > d_ticketKeys;
178 pthread_rwlock_t d_lock;
179};
180
181class OpenSSLTLSConnection: public TLSConnection
182{
183public:
d0ae6360 184 OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free)), d_timeout(timeout)
a227f47d
RG
185 {
186 d_socket = socket;
a227f47d
RG
187
188 if (!d_conn) {
189 vinfolog("Error creating TLS object");
190 if (g_verbose) {
191 ERR_print_errors_fp(stderr);
192 }
193 throw std::runtime_error("Error creating TLS object");
194 }
195
8dd7033b 196 if (!SSL_set_fd(d_conn.get(), d_socket)) {
a227f47d
RG
197 throw std::runtime_error("Error assigning socket");
198 }
d0ae6360
RG
199 }
200
201 IOState convertIORequestToIOState(int res) const
202 {
203 int error = SSL_get_error(d_conn.get(), res);
204 if (error == SSL_ERROR_WANT_READ) {
205 return IOState::NeedRead;
206 }
207 else if (error == SSL_ERROR_WANT_WRITE) {
208 return IOState::NeedWrite;
209 }
87cdc63a
RG
210 else if (error == SSL_ERROR_SYSCALL) {
211 throw std::runtime_error("Error while processing TLS connection:" + std::string(strerror(errno)));
212 }
d0ae6360
RG
213 else {
214 throw std::runtime_error("Error while processing TLS connection:" + std::to_string(error));
215 }
216 }
217
218 void handleIORequest(int res, unsigned int timeout)
219 {
220 auto state = convertIORequestToIOState(res);
221 if (state == IOState::NeedRead) {
222 res = waitForData(d_socket, timeout);
223 if (res <= 0) {
224 throw std::runtime_error("Error reading from TLS connection");
225 }
226 }
227 else if (state == IOState::NeedWrite) {
228 res = waitForRWData(d_socket, false, timeout, 0);
229 if (res <= 0) {
230 throw std::runtime_error("Error waiting to write to TLS connection");
231 }
232 }
233 }
234
3163698b 235 IOState tryHandshake() override
d0ae6360
RG
236 {
237 int res = SSL_accept(d_conn.get());
238 if (res == 1) {
239 return IOState::Done;
240 }
241 else if (res < 0) {
242 return convertIORequestToIOState(res);
243 }
244
245 throw std::runtime_error("Error accepting TLS connection");
246 }
a227f47d 247
3163698b 248 void doHandshake() override
d0ae6360 249 {
a227f47d
RG
250 int res = 0;
251 do {
8dd7033b 252 res = SSL_accept(d_conn.get());
a227f47d 253 if (res < 0) {
d0ae6360 254 handleIORequest(res, d_timeout);
a227f47d
RG
255 }
256 }
257 while (res < 0);
258
259 if (res != 1) {
260 throw std::runtime_error("Error accepting TLS connection");
261 }
262 }
263
d0ae6360 264 IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
a227f47d 265 {
d0ae6360
RG
266 do {
267 int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
268 if (res == 0) {
269 throw std::runtime_error("Error writing to TLS connection");
a227f47d 270 }
d0ae6360
RG
271 else if (res < 0) {
272 return convertIORequestToIOState(res);
273 }
274 else {
275 pos += static_cast<size_t>(res);
a227f47d
RG
276 }
277 }
d0ae6360
RG
278 while (pos < toWrite);
279 return IOState::Done;
280 }
281
282 IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
283 {
284 do {
285 int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
286 if (res == 0) {
287 throw std::runtime_error("Error reading from TLS connection");
288 }
289 else if (res < 0) {
290 return convertIORequestToIOState(res);
291 }
292 else {
293 pos += static_cast<size_t>(res);
294 }
a227f47d 295 }
d0ae6360
RG
296 while (pos < toRead);
297 return IOState::Done;
a227f47d
RG
298 }
299
300 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
301 {
302 size_t got = 0;
303 time_t start = 0;
304 unsigned int remainingTime = totalTimeout;
305 if (totalTimeout) {
306 start = time(nullptr);
307 }
308
309 do {
8dd7033b 310 int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
a227f47d
RG
311 if (res == 0) {
312 throw std::runtime_error("Error reading from TLS connection");
313 }
314 else if (res < 0) {
315 handleIORequest(res, readTimeout);
316 }
317 else {
d0ae6360 318 got += static_cast<size_t>(res);
a227f47d
RG
319 }
320
321 if (totalTimeout) {
322 time_t now = time(nullptr);
323 unsigned int elapsed = now - start;
324 if (now < start || elapsed >= remainingTime) {
325 throw runtime_error("Timeout while reading data");
326 }
327 start = now;
328 remainingTime -= elapsed;
329 }
330 }
331 while (got < bufferSize);
332
333 return got;
334 }
335
336 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
337 {
338 size_t got = 0;
339 do {
8dd7033b 340 int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
a227f47d
RG
341 if (res == 0) {
342 throw std::runtime_error("Error writing to TLS connection");
343 }
344 else if (res < 0) {
345 handleIORequest(res, writeTimeout);
346 }
347 else {
d0ae6360 348 got += static_cast<size_t>(res);
a227f47d
RG
349 }
350 }
351 while (got < bufferSize);
352
353 return got;
354 }
355 void close() override
356 {
357 if (d_conn) {
8dd7033b 358 SSL_shutdown(d_conn.get());
a227f47d
RG
359 }
360 }
361
362private:
8dd7033b 363 std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
d0ae6360 364 unsigned int d_timeout;
a227f47d
RG
365};
366
367class OpenSSLTLSIOCtx: public TLSCtx
368{
369public:
8dd7033b 370 OpenSSLTLSIOCtx(const TLSFrontend& fe): d_ticketKeys(fe.d_numberOfTicketsKeys), d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
a227f47d
RG
371 {
372 d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay;
373
c8d7b468 374 int sslOptions =
a227f47d
RG
375 SSL_OP_NO_SSLv2 |
376 SSL_OP_NO_SSLv3 |
377 SSL_OP_NO_COMPRESSION |
378 SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
379 SSL_OP_SINGLE_DH_USE |
380 SSL_OP_SINGLE_ECDH_USE |
381 SSL_OP_CIPHER_SERVER_PREFERENCE;
382
ba20dc97 383 if (!fe.d_enableTickets) {
c8d7b468
RG
384 sslOptions |= SSL_OP_NO_TICKET;
385 }
386
a227f47d 387 if (s_users.fetch_add(1) == 0) {
ede152ec 388 registerOpenSSLUser();
a227f47d 389
0ca5d025 390 s_ticketsKeyIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
a227f47d
RG
391
392 if (s_ticketsKeyIndex == -1) {
393 throw std::runtime_error("Error getting an index for tickets key");
394 }
395 }
396
8dd7033b 397 d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free);
a227f47d
RG
398 if (!d_tlsCtx) {
399 ERR_print_errors_fp(stderr);
400 throw std::runtime_error("Error creating TLS context on " + fe.d_addr.toStringWithPort());
401 }
402
a227f47d 403 /* use our own ticket keys handler so we can rotate them */
8dd7033b
RG
404 SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
405 SSL_CTX_set_ex_data(d_tlsCtx.get(), s_ticketsKeyIndex, this);
406 SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
a227f47d 407#if defined(SSL_CTX_set_ecdh_auto)
8dd7033b 408 SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
a227f47d 409#endif
d395c941
RG
410 if (fe.d_maxStoredSessions == 0) {
411 /* disable stored sessions entirely */
412 SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_OFF);
413 }
414 else {
415 /* use the internal built-in cache to store sessions */
416 SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_SERVER);
417 SSL_CTX_sess_set_cache_size(d_tlsCtx.get(), fe.d_maxStoredSessions);
418 }
fa974ada
RG
419
420 for (const auto& pair : fe.d_certKeyPairs) {
8dd7033b 421 if (SSL_CTX_use_certificate_chain_file(d_tlsCtx.get(), pair.first.c_str()) != 1) {
fa974ada
RG
422 ERR_print_errors_fp(stderr);
423 throw std::runtime_error("Error loading certificate from " + pair.first + " for the TLS context on " + fe.d_addr.toStringWithPort());
424 }
8dd7033b 425 if (SSL_CTX_use_PrivateKey_file(d_tlsCtx.get(), pair.second.c_str(), SSL_FILETYPE_PEM) != 1) {
fa974ada
RG
426 ERR_print_errors_fp(stderr);
427 throw std::runtime_error("Error loading key from " + pair.second + " for the TLS context on " + fe.d_addr.toStringWithPort());
428 }
429 }
a227f47d
RG
430
431 if (!fe.d_ciphers.empty()) {
8dd7033b 432 if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), fe.d_ciphers.c_str()) != 1) {
fa974ada
RG
433 ERR_print_errors_fp(stderr);
434 throw std::runtime_error("Error setting the cipher list to '" + fe.d_ciphers + "' for the TLS context on " + fe.d_addr.toStringWithPort());
435 }
a227f47d
RG
436 }
437
9e67ac67
RG
438#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
439 if (!fe.d_ciphers13.empty()) {
440 if (SSL_CTX_set_ciphersuites(d_tlsCtx.get(), fe.d_ciphers13.c_str()) != 1) {
441 ERR_print_errors_fp(stderr);
442 throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + fe.d_ciphers13 + "' for the TLS context on " + fe.d_addr.toStringWithPort());
443 }
444 }
445#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
446
a227f47d
RG
447 try {
448 if (fe.d_ticketKeyFile.empty()) {
449 handleTicketsKeyRotation(time(nullptr));
450 }
451 else {
452 loadTicketsKeys(fe.d_ticketKeyFile);
453 }
454 }
455 catch (const std::exception& e) {
a227f47d
RG
456 throw;
457 }
458 }
459
460 virtual ~OpenSSLTLSIOCtx() override
461 {
8dd7033b 462 d_tlsCtx.reset();
a227f47d
RG
463
464 if (s_users.fetch_sub(1) == 1) {
ede152ec 465 unregisterOpenSSLUser();
a227f47d
RG
466 }
467 }
468
469 static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
470 {
471 SSL_CTX* sslCtx = SSL_get_SSL_CTX(s);
472 if (sslCtx == nullptr) {
473 return -1;
474 }
475
476 OpenSSLTLSIOCtx* ctx = reinterpret_cast<OpenSSLTLSIOCtx*>(SSL_CTX_get_ex_data(sslCtx, s_ticketsKeyIndex));
477 if (ctx == nullptr) {
478 return -1;
479 }
480
481 if (enc) {
482 const auto key = ctx->d_ticketKeys.getEncryptionKey();
483 if (key == nullptr) {
484 return -1;
485 }
486
487 return key->encrypt(keyName, iv, ectx, hctx);
488 }
489
490 bool activeEncryptionKey = false;
491
492 const auto key = ctx->d_ticketKeys.getDecryptionKey(keyName, activeEncryptionKey);
493 if (key == nullptr) {
494 /* we don't know this key, just create a new ticket */
495 return 0;
496 }
497
498 if (key->decrypt(iv, ectx, hctx) == false) {
499 return -1;
500 }
501
502 if (!activeEncryptionKey) {
503 /* this key is not active, please encrypt the ticket content with the currently active one */
504 return 2;
505 }
506
507 return 1;
508 }
509
510 std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
511 {
512 handleTicketsKeyRotation(now);
513
8dd7033b 514 return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_tlsCtx.get()));
a227f47d
RG
515 }
516
517 void rotateTicketsKey(time_t now) override
518 {
519 auto newKey = std::make_shared<OpenSSLTLSTicketKey>();
520 d_ticketKeys.addKey(newKey);
521
522 if (d_ticketsKeyRotationDelay > 0) {
2d29e6b7 523 d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
a227f47d
RG
524 }
525 }
526
527 void loadTicketsKeys(const std::string& keyFile) override
528 {
529 bool keyLoaded = false;
530 ifstream file(keyFile);
531 try {
532 do {
533 auto newKey = std::make_shared<OpenSSLTLSTicketKey>(file);
534 d_ticketKeys.addKey(newKey);
535 keyLoaded = true;
536 }
537 while (!file.fail());
538 }
539 catch (const std::exception& e) {
540 /* if we haven't been able to load at least one key, fail */
541 if (!keyLoaded) {
542 throw;
543 }
544 }
545
546 if (d_ticketsKeyRotationDelay > 0) {
547 d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
548 }
549
550 file.close();
551 }
552
553 size_t getTicketsKeysCount() override
554 {
555 return d_ticketKeys.getKeysCount();
556 }
557
558private:
559 OpenSSLTLSTicketKeysRing d_ticketKeys;
8dd7033b 560 std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx;
a227f47d
RG
561 static std::atomic<uint64_t> s_users;
562};
563
564std::atomic<uint64_t> OpenSSLTLSIOCtx::s_users(0);
565
566#endif /* HAVE_LIBSSL */
567
568#ifdef HAVE_GNUTLS
569#include <gnutls/gnutls.h>
570#include <gnutls/x509.h>
571
/* Ask for [data, data + size) to be kept out of swap. Only libsodium's
   sodium_mlock() is used for this, so the call does nothing when libsodium
   is not available. */
void safe_memory_lock(void* data, size_t size)
{
#ifndef HAVE_LIBSODIUM
  (void) data;
  (void) size;
#else
  sodium_mlock(data, size);
#endif
}
578
/* Zero [data, data + size) in a way the compiler cannot optimize away,
   using the best primitive available at build time. With libsodium, the
   range is also unlocked (sodium_munlock() wipes it first). */
void safe_memory_release(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  /* shamelessly taken from Dovecot's src/lib/safe-memset.c: reading the
     memory back through a volatile pointer, at an index the compiler cannot
     know, keeps the memset() from being removed as a dead store */
  if (size == 0) {
    return;
  }

  volatile unsigned int zeroIdx = 0;
  volatile unsigned char* bytes = reinterpret_cast<volatile unsigned char*>(data);
  do {
    memset(data, 0, size);
  }
  while (bytes[zeroIdx] != 0);
#endif
}
68aaaa06 602
a227f47d
RG
603class GnuTLSTicketsKey
604{
605public:
606 GnuTLSTicketsKey()
607 {
608 if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
609 throw std::runtime_error("Error generating tickets key for TLS context");
610 }
611
7e81628b 612 safe_memory_lock(d_key.data, d_key.size);
a227f47d
RG
613 }
614
615 GnuTLSTicketsKey(const std::string& keyFile)
616 {
617 /* to be sure we are loading the correct amount of data, which
618 may change between versions, let's generate a correct key first */
619 if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
620 throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
621 }
622
7e81628b 623 safe_memory_lock(d_key.data, d_key.size);
a227f47d
RG
624
625 try {
626 ifstream file(keyFile);
627 file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
628
629 if (file.fail()) {
630 file.close();
631 throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
632 }
633
634 file.close();
635 }
636 catch (const std::exception& e) {
7e81628b 637 safe_memory_release(d_key.data, d_key.size);
a227f47d 638 gnutls_free(d_key.data);
1f65c18c 639 d_key.data = nullptr;
a227f47d
RG
640 throw;
641 }
642 }
643
644 ~GnuTLSTicketsKey()
645 {
646 if (d_key.data != nullptr && d_key.size > 0) {
7e81628b 647 safe_memory_release(d_key.data, d_key.size);
a227f47d
RG
648 }
649 gnutls_free(d_key.data);
1f65c18c 650 d_key.data = nullptr;
a227f47d
RG
651 }
652 const gnutls_datum_t& getKey() const
653 {
654 return d_key;
655 }
656
657private:
658 gnutls_datum_t d_key{nullptr, 0};
659};
660
661class GnuTLSConnection: public TLSConnection
662{
663public:
664
8dd7033b 665 GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
a227f47d 666 {
d0ae6360 667 unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
a227f47d 668#ifdef GNUTLS_NO_SIGNAL
c8d7b468 669 sslOptions |= GNUTLS_NO_SIGNAL;
a227f47d 670#endif
c8d7b468
RG
671
672 d_socket = socket;
673
8dd7033b
RG
674 gnutls_session_t conn;
675 if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
a227f47d
RG
676 throw std::runtime_error("Error creating TLS connection");
677 }
678
8dd7033b
RG
679 d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
680 conn = nullptr;
681
682 if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
a227f47d
RG
683 throw std::runtime_error("Error setting certificate and key to TLS connection");
684 }
685
8dd7033b 686 if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
a227f47d
RG
687 throw std::runtime_error("Error setting ciphers to TLS connection");
688 }
689
ba20dc97 690 if (enableTickets && d_ticketsKey) {
a227f47d 691 const gnutls_datum_t& key = d_ticketsKey->getKey();
8dd7033b 692 if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
a227f47d
RG
693 throw std::runtime_error("Error setting the tickets key to TLS connection");
694 }
695 }
696
8dd7033b 697 gnutls_transport_set_int(d_conn.get(), d_socket);
a227f47d
RG
698
699 /* timeouts are in milliseconds */
8dd7033b
RG
700 gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
701 gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
d0ae6360 702 }
a227f47d 703
3163698b 704 void doHandshake() override
d0ae6360 705 {
a227f47d
RG
706 int ret = 0;
707 do {
8dd7033b 708 ret = gnutls_handshake(d_conn.get());
d0ae6360
RG
709 if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
710 throw std::runtime_error("Error accepting a new connection");
711 }
712 }
713 while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
714 }
715
3163698b 716 IOState tryHandshake() override
d0ae6360
RG
717 {
718 int ret = 0;
719
720 do {
721 ret = gnutls_handshake(d_conn.get());
722 if (ret == GNUTLS_E_SUCCESS) {
723 return IOState::Done;
724 }
725 else if (ret == GNUTLS_E_AGAIN) {
726 return IOState::NeedRead;
727 }
728 else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
729 throw std::runtime_error("Error accepting a new connection");
730 }
731 } while (ret == GNUTLS_E_INTERRUPTED);
732
733 throw std::runtime_error("Error accepting a new connection");
734 }
735
736 IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
737 {
738 do {
739 ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
740 if (res == 0) {
741 throw std::runtime_error("Error writing to TLS connection");
742 }
743 else if (res > 0) {
744 pos += static_cast<size_t>(res);
745 }
746 else if (res < 0) {
747 if (gnutls_error_is_fatal(res)) {
748 throw std::runtime_error("Error writing to TLS connection");
749 }
750 else if (res == GNUTLS_E_AGAIN) {
751 return IOState::NeedWrite;
752 }
753 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
754 }
755 }
756 while (pos < toWrite);
757 return IOState::Done;
758 }
759
760 IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
761 {
762 do {
763 ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
764 if (res == 0) {
765 throw std::runtime_error("Error reading from TLS connection");
766 }
767 else if (res > 0) {
768 pos += static_cast<size_t>(res);
769 }
770 else if (res < 0) {
771 if (gnutls_error_is_fatal(res)) {
772 throw std::runtime_error("Error reading from TLS connection");
773 }
774 else if (res == GNUTLS_E_AGAIN) {
775 return IOState::NeedRead;
776 }
777 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
778 }
a227f47d 779 }
d0ae6360
RG
780 while (pos < toRead);
781 return IOState::Done;
a227f47d
RG
782 }
783
a227f47d
RG
784 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
785 {
786 size_t got = 0;
787 time_t start = 0;
788 unsigned int remainingTime = totalTimeout;
789 if (totalTimeout) {
790 start = time(nullptr);
791 }
792
793 do {
8dd7033b 794 ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
a227f47d
RG
795 if (res == 0) {
796 throw std::runtime_error("Error reading from TLS connection");
797 }
798 else if (res > 0) {
d0ae6360 799 got += static_cast<size_t>(res);
a227f47d
RG
800 }
801 else if (res < 0) {
802 if (gnutls_error_is_fatal(res)) {
d7aff6e6
RG
803 throw std::runtime_error("Error reading from TLS connection:" + std::string(gnutls_strerror(res)));
804 }
805 else if (res == GNUTLS_E_AGAIN) {
806 int result = waitForData(d_socket, readTimeout);
807 if (result <= 0) {
808 throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
809 }
810 }
811 else {
812 vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
a227f47d 813 }
a227f47d
RG
814 }
815
816 if (totalTimeout) {
817 time_t now = time(nullptr);
818 unsigned int elapsed = now - start;
819 if (now < start || elapsed >= remainingTime) {
820 throw runtime_error("Timeout while reading data");
821 }
822 start = now;
823 remainingTime -= elapsed;
824 }
825 }
826 while (got < bufferSize);
827
828 return got;
829 }
830
831 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
832 {
833 size_t got = 0;
834
835 do {
8dd7033b 836 ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
a227f47d
RG
837 if (res == 0) {
838 throw std::runtime_error("Error writing to TLS connection");
839 }
840 else if (res > 0) {
d0ae6360 841 got += static_cast<size_t>(res);
a227f47d
RG
842 }
843 else if (res < 0) {
844 if (gnutls_error_is_fatal(res)) {
d7aff6e6
RG
845 throw std::runtime_error("Error writing to TLS connection: " + std::string(gnutls_strerror(res)));
846 }
847 else if (res == GNUTLS_E_AGAIN) {
848 int result = waitForRWData(d_socket, false, writeTimeout, 0);
849 if (result <= 0) {
850 throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
851 }
852 }
853 else {
854 vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
a227f47d 855 }
a227f47d
RG
856 }
857 }
858 while (got < bufferSize);
859
860 return got;
861 }
862
863 void close() override
864 {
865 if (d_conn) {
8dd7033b 866 gnutls_bye(d_conn.get(), GNUTLS_SHUT_WR);
a227f47d
RG
867 }
868 }
869
870private:
8dd7033b 871 std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
a227f47d
RG
872 std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
873};
874
875class GnuTLSIOCtx: public TLSCtx
876{
877public:
8dd7033b 878 GnuTLSIOCtx(const TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_enableTickets)
a227f47d 879 {
0ca5d025 880 int rc = 0;
a227f47d
RG
881 d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay;
882
8dd7033b
RG
883 gnutls_certificate_credentials_t creds;
884 rc = gnutls_certificate_allocate_credentials(&creds);
0ca5d025
RG
885 if (rc != GNUTLS_E_SUCCESS) {
886 throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
a227f47d
RG
887 }
888
8dd7033b
RG
889 d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
890 creds = nullptr;
891
fa974ada 892 for (const auto& pair : fe.d_certKeyPairs) {
8dd7033b 893 rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM);
fa974ada 894 if (rc != GNUTLS_E_SUCCESS) {
fa974ada
RG
895 throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
896 }
a227f47d
RG
897 }
898
899#if GNUTLS_VERSION_NUMBER >= 0x030600
8dd7033b 900 rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
0ca5d025 901 if (rc != GNUTLS_E_SUCCESS) {
0ca5d025 902 throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
a227f47d
RG
903 }
904#endif
905
0ca5d025
RG
906 rc = gnutls_priority_init(&d_priorityCache, fe.d_ciphers.empty() ? "NORMAL" : fe.d_ciphers.c_str(), nullptr);
907 if (rc != GNUTLS_E_SUCCESS) {
d26d6848 908 throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort());
a227f47d
RG
909 }
910
1f65c18c
RG
911 pthread_rwlock_init(&d_lock, nullptr);
912
a227f47d
RG
913 try {
914 if (fe.d_ticketKeyFile.empty()) {
915 handleTicketsKeyRotation(time(nullptr));
916 }
917 else {
918 loadTicketsKeys(fe.d_ticketKeyFile);
919 }
920 }
921 catch(const std::runtime_error& e) {
1f65c18c 922 pthread_rwlock_destroy(&d_lock);
a227f47d
RG
923 throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
924 }
925 }
926
927 virtual ~GnuTLSIOCtx() override
928 {
1f65c18c
RG
929 pthread_rwlock_destroy(&d_lock);
930
8dd7033b
RG
931 d_creds.reset();
932
a227f47d
RG
933 if (d_priorityCache) {
934 gnutls_priority_deinit(d_priorityCache);
935 }
936 }
937
938 std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
939 {
940 handleTicketsKeyRotation(now);
941
1f65c18c
RG
942 std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
943 {
944 ReadLock rl(&d_lock);
945 ticketsKey = d_ticketsKey;
946 }
947
948 return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets));
a227f47d
RG
949 }
950
951 void rotateTicketsKey(time_t now) override
952 {
ba20dc97 953 if (!d_enableTickets) {
c8d7b468
RG
954 return;
955 }
956
a227f47d 957 auto newKey = std::make_shared<GnuTLSTicketsKey>();
1f65c18c
RG
958
959 {
960 WriteLock wl(&d_lock);
961 d_ticketsKey = newKey;
962 }
963
a227f47d 964 if (d_ticketsKeyRotationDelay > 0) {
2d29e6b7 965 d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
a227f47d
RG
966 }
967 }
968
969 void loadTicketsKeys(const std::string& file) override
970 {
ba20dc97 971 if (!d_enableTickets) {
c8d7b468
RG
972 return;
973 }
974
a227f47d 975 auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
1f65c18c
RG
976 {
977 WriteLock wl(&d_lock);
978 d_ticketsKey = newKey;
979 }
980
a227f47d
RG
981 if (d_ticketsKeyRotationDelay > 0) {
982 d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
983 }
984 }
985
986 size_t getTicketsKeysCount() override
987 {
1f65c18c 988 ReadLock rl(&d_lock);
a227f47d
RG
989 return d_ticketsKey != nullptr ? 1 : 0;
990 }
991
992private:
8dd7033b 993 std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
a227f47d
RG
994 gnutls_priority_t d_priorityCache{nullptr};
995 std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr};
1f65c18c 996 pthread_rwlock_t d_lock;
ba20dc97 997 bool d_enableTickets{true};
a227f47d
RG
998};
999
1000#endif /* HAVE_GNUTLS */
1001
1002#endif /* HAVE_DNS_OVER_TLS */
1003
1004bool TLSFrontend::setupTLS()
1005{
1006#ifdef HAVE_DNS_OVER_TLS
1007 /* get the "best" available provider */
1008 if (!d_provider.empty()) {
1009#ifdef HAVE_GNUTLS
1010 if (d_provider == "gnutls") {
1011 d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
1012 return true;
1013 }
1014#endif /* HAVE_GNUTLS */
1015#ifdef HAVE_LIBSSL
1016 if (d_provider == "openssl") {
1017 d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
1018 return true;
1019 }
1020#endif /* HAVE_LIBSSL */
1021 }
1022#ifdef HAVE_GNUTLS
1023 d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
1024#else /* HAVE_GNUTLS */
1025#ifdef HAVE_LIBSSL
1026 d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
1027#endif /* HAVE_LIBSSL */
1028#endif /* HAVE_GNUTLS */
1029
1030#endif /* HAVE_DNS_OVER_TLS */
1031 return true;
1032}