]>
Commit | Line | Data |
---|---|---|
a227f47d RG |
1 | |
2 | #include "config.h" | |
3 | #include "dolog.hh" | |
4 | #include "iputils.hh" | |
5 | #include "lock.hh" | |
6 | #include "tcpiohandler.hh" | |
7 | ||
8 | #ifdef HAVE_LIBSODIUM | |
9 | #include <sodium.h> | |
10 | #endif /* HAVE_LIBSODIUM */ | |
11 | ||
12 | #ifdef HAVE_DNS_OVER_TLS | |
13 | #ifdef HAVE_LIBSSL | |
14 | #include <openssl/conf.h> | |
15 | #include <openssl/err.h> | |
16 | #include <openssl/rand.h> | |
17 | #include <openssl/ssl.h> | |
18 | ||
ede152ec | 19 | #include "libssl.hh" |
a227f47d | 20 | |
33a55a38 RG |
21 | class OpenSSLFrontendContext |
22 | { | |
23 | public: | |
24 | OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys) | |
25 | { | |
26 | registerOpenSSLUser(); | |
27 | ||
28 | d_tlsCtx = libssl_init_server_context(tlsConfig, d_ocspResponses); | |
29 | if (!d_tlsCtx) { | |
30 | ERR_print_errors_fp(stderr); | |
31 | throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort()); | |
32 | } | |
33 | } | |
34 | ||
35 | void cleanup() | |
36 | { | |
37 | d_tlsCtx.reset(); | |
38 | ||
39 | unregisterOpenSSLUser(); | |
40 | } | |
41 | ||
42 | OpenSSLTLSTicketKeysRing d_ticketKeys; | |
43 | std::map<int, std::string> d_ocspResponses; | |
44 | std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free}; | |
0a530e9d | 45 | std::unique_ptr<FILE, int(*)(FILE*)> d_keyLogFile{nullptr, fclose}; |
33a55a38 RG |
46 | }; |
47 | ||
a227f47d RG |
48 | class OpenSSLTLSConnection: public TLSConnection |
49 | { | |
50 | public: | |
33a55a38 | 51 | OpenSSLTLSConnection(int socket, unsigned int timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(feContext), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout) |
a227f47d RG |
52 | { |
53 | d_socket = socket; | |
a227f47d | 54 | |
b608e6c6 RG |
55 | if (!s_initTLSConnIndex.test_and_set()) { |
56 | /* not initialized yet */ | |
57 | s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); | |
58 | if (s_tlsConnIndex == -1) { | |
59 | throw std::runtime_error("Error getting an index for TLS connection data"); | |
60 | } | |
61 | } | |
62 | ||
a227f47d RG |
63 | if (!d_conn) { |
64 | vinfolog("Error creating TLS object"); | |
65 | if (g_verbose) { | |
66 | ERR_print_errors_fp(stderr); | |
67 | } | |
68 | throw std::runtime_error("Error creating TLS object"); | |
69 | } | |
70 | ||
8dd7033b | 71 | if (!SSL_set_fd(d_conn.get(), d_socket)) { |
a227f47d RG |
72 | throw std::runtime_error("Error assigning socket"); |
73 | } | |
b608e6c6 RG |
74 | |
75 | SSL_set_ex_data(d_conn.get(), s_tlsConnIndex, this); | |
d0ae6360 RG |
76 | } |
77 | ||
78 | IOState convertIORequestToIOState(int res) const | |
79 | { | |
80 | int error = SSL_get_error(d_conn.get(), res); | |
81 | if (error == SSL_ERROR_WANT_READ) { | |
82 | return IOState::NeedRead; | |
83 | } | |
84 | else if (error == SSL_ERROR_WANT_WRITE) { | |
85 | return IOState::NeedWrite; | |
86 | } | |
87cdc63a | 87 | else if (error == SSL_ERROR_SYSCALL) { |
11102d05 | 88 | throw std::runtime_error("Error while processing TLS connection: " + std::string(strerror(errno))); |
87cdc63a | 89 | } |
d0ae6360 | 90 | else { |
11102d05 | 91 | throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error)); |
d0ae6360 RG |
92 | } |
93 | } | |
94 | ||
95 | void handleIORequest(int res, unsigned int timeout) | |
96 | { | |
97 | auto state = convertIORequestToIOState(res); | |
98 | if (state == IOState::NeedRead) { | |
99 | res = waitForData(d_socket, timeout); | |
11102d05 RG |
100 | if (res == 0) { |
101 | throw std::runtime_error("Timeout while reading from TLS connection"); | |
102 | } | |
103 | else if (res < 0) { | |
104 | throw std::runtime_error("Error waiting to read from TLS connection"); | |
d0ae6360 RG |
105 | } |
106 | } | |
107 | else if (state == IOState::NeedWrite) { | |
108 | res = waitForRWData(d_socket, false, timeout, 0); | |
11102d05 RG |
109 | if (res == 0) { |
110 | throw std::runtime_error("Timeout while writing to TLS connection"); | |
111 | } | |
112 | else if (res < 0) { | |
d0ae6360 RG |
113 | throw std::runtime_error("Error waiting to write to TLS connection"); |
114 | } | |
115 | } | |
116 | } | |
117 | ||
3163698b | 118 | IOState tryHandshake() override |
d0ae6360 RG |
119 | { |
120 | int res = SSL_accept(d_conn.get()); | |
121 | if (res == 1) { | |
122 | return IOState::Done; | |
123 | } | |
124 | else if (res < 0) { | |
125 | return convertIORequestToIOState(res); | |
126 | } | |
127 | ||
128 | throw std::runtime_error("Error accepting TLS connection"); | |
129 | } | |
a227f47d | 130 | |
3163698b | 131 | void doHandshake() override |
d0ae6360 | 132 | { |
a227f47d RG |
133 | int res = 0; |
134 | do { | |
8dd7033b | 135 | res = SSL_accept(d_conn.get()); |
a227f47d | 136 | if (res < 0) { |
d0ae6360 | 137 | handleIORequest(res, d_timeout); |
a227f47d RG |
138 | } |
139 | } | |
140 | while (res < 0); | |
141 | ||
142 | if (res != 1) { | |
143 | throw std::runtime_error("Error accepting TLS connection"); | |
144 | } | |
145 | } | |
146 | ||
d0ae6360 | 147 | IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override |
a227f47d | 148 | { |
d0ae6360 RG |
149 | do { |
150 | int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos)); | |
11102d05 | 151 | if (res <= 0) { |
d0ae6360 RG |
152 | return convertIORequestToIOState(res); |
153 | } | |
154 | else { | |
155 | pos += static_cast<size_t>(res); | |
a227f47d RG |
156 | } |
157 | } | |
d0ae6360 RG |
158 | while (pos < toWrite); |
159 | return IOState::Done; | |
160 | } | |
161 | ||
162 | IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override | |
163 | { | |
164 | do { | |
165 | int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos)); | |
11102d05 | 166 | if (res <= 0) { |
d0ae6360 RG |
167 | return convertIORequestToIOState(res); |
168 | } | |
169 | else { | |
170 | pos += static_cast<size_t>(res); | |
171 | } | |
a227f47d | 172 | } |
d0ae6360 RG |
173 | while (pos < toRead); |
174 | return IOState::Done; | |
a227f47d RG |
175 | } |
176 | ||
177 | size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override | |
178 | { | |
179 | size_t got = 0; | |
180 | time_t start = 0; | |
181 | unsigned int remainingTime = totalTimeout; | |
182 | if (totalTimeout) { | |
183 | start = time(nullptr); | |
184 | } | |
185 | ||
186 | do { | |
8dd7033b | 187 | int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got)); |
11102d05 | 188 | if (res <= 0) { |
a227f47d RG |
189 | handleIORequest(res, readTimeout); |
190 | } | |
191 | else { | |
d0ae6360 | 192 | got += static_cast<size_t>(res); |
a227f47d RG |
193 | } |
194 | ||
195 | if (totalTimeout) { | |
196 | time_t now = time(nullptr); | |
197 | unsigned int elapsed = now - start; | |
198 | if (now < start || elapsed >= remainingTime) { | |
199 | throw runtime_error("Timeout while reading data"); | |
200 | } | |
201 | start = now; | |
202 | remainingTime -= elapsed; | |
203 | } | |
204 | } | |
205 | while (got < bufferSize); | |
206 | ||
207 | return got; | |
208 | } | |
209 | ||
210 | size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override | |
211 | { | |
212 | size_t got = 0; | |
213 | do { | |
8dd7033b | 214 | int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got)); |
11102d05 | 215 | if (res <= 0) { |
a227f47d RG |
216 | handleIORequest(res, writeTimeout); |
217 | } | |
218 | else { | |
d0ae6360 | 219 | got += static_cast<size_t>(res); |
a227f47d RG |
220 | } |
221 | } | |
222 | while (got < bufferSize); | |
223 | ||
224 | return got; | |
225 | } | |
046bac5c | 226 | |
a227f47d RG |
227 | void close() override |
228 | { | |
229 | if (d_conn) { | |
8dd7033b | 230 | SSL_shutdown(d_conn.get()); |
a227f47d RG |
231 | } |
232 | } | |
233 | ||
bb3954f0 | 234 | std::string getServerNameIndication() const override |
046bac5c RG |
235 | { |
236 | if (d_conn) { | |
237 | const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name); | |
238 | if (value) { | |
239 | return std::string(value); | |
240 | } | |
241 | } | |
242 | return std::string(); | |
243 | } | |
244 | ||
bb3954f0 RG |
245 | LibsslTLSVersion getTLSVersion() const override |
246 | { | |
247 | auto proto = SSL_version(d_conn.get()); | |
248 | switch (proto) { | |
249 | case TLS1_VERSION: | |
250 | return LibsslTLSVersion::TLS10; | |
251 | case TLS1_1_VERSION: | |
252 | return LibsslTLSVersion::TLS11; | |
253 | case TLS1_2_VERSION: | |
254 | return LibsslTLSVersion::TLS12; | |
255 | #ifdef TLS1_3_VERSION | |
256 | case TLS1_3_VERSION: | |
257 | return LibsslTLSVersion::TLS13; | |
258 | #endif /* TLS1_3_VERSION */ | |
259 | default: | |
260 | return LibsslTLSVersion::Unknown; | |
261 | } | |
262 | } | |
263 | ||
846b63bb RG |
264 | bool hasSessionBeenResumed() const override |
265 | { | |
266 | if (d_conn) { | |
267 | return SSL_session_reused(d_conn.get()) != 0; | |
268 | } | |
269 | return false; | |
270 | } | |
271 | ||
b608e6c6 | 272 | static int s_tlsConnIndex; |
846b63bb | 273 | |
a227f47d | 274 | private: |
b608e6c6 RG |
275 | static std::atomic_flag s_initTLSConnIndex; |
276 | ||
33a55a38 | 277 | std::shared_ptr<OpenSSLFrontendContext> d_feContext; |
8dd7033b | 278 | std::unique_ptr<SSL, void(*)(SSL*)> d_conn; |
d0ae6360 | 279 | unsigned int d_timeout; |
a227f47d RG |
280 | }; |
281 | ||
b608e6c6 RG |
282 | std::atomic_flag OpenSSLTLSConnection::s_initTLSConnIndex = ATOMIC_FLAG_INIT; |
283 | int OpenSSLTLSConnection::s_tlsConnIndex = -1; | |
284 | ||
a227f47d RG |
285 | class OpenSSLTLSIOCtx: public TLSCtx |
286 | { | |
287 | public: | |
33a55a38 | 288 | OpenSSLTLSIOCtx(TLSFrontend& fe) |
a227f47d | 289 | { |
33a55a38 | 290 | d_feContext = std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig); |
a227f47d | 291 | |
33a55a38 | 292 | d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay; |
a227f47d | 293 | |
b54e94dc | 294 | if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) { |
4ecc5603 | 295 | /* use our own ticket keys handler so we can rotate them */ |
33a55a38 RG |
296 | SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb); |
297 | libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get()); | |
4ecc5603 | 298 | } |
0ef9ab19 | 299 | |
33a55a38 RG |
300 | if (!d_feContext->d_ocspResponses.empty()) { |
301 | SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb); | |
302 | SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses); | |
fa974ada | 303 | } |
a227f47d | 304 | |
33a55a38 | 305 | libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters); |
f34fdcc5 | 306 | |
0a530e9d RG |
307 | if (!fe.d_tlsConfig.d_keyLogFile.empty()) { |
308 | d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile); | |
309 | } | |
310 | ||
a227f47d | 311 | try { |
b54e94dc | 312 | if (fe.d_tlsConfig.d_ticketKeyFile.empty()) { |
a227f47d RG |
313 | handleTicketsKeyRotation(time(nullptr)); |
314 | } | |
315 | else { | |
8dfec397 | 316 | OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile); |
a227f47d RG |
317 | } |
318 | } | |
319 | catch (const std::exception& e) { | |
a227f47d RG |
320 | throw; |
321 | } | |
322 | } | |
323 | ||
33a55a38 | 324 | ~OpenSSLTLSIOCtx() override |
a227f47d | 325 | { |
a227f47d RG |
326 | } |
327 | ||
328 | static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc) | |
329 | { | |
33a55a38 | 330 | OpenSSLFrontendContext* ctx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(s)); |
a227f47d RG |
331 | if (ctx == nullptr) { |
332 | return -1; | |
333 | } | |
334 | ||
b608e6c6 RG |
335 | int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc); |
336 | if (enc == 0) { | |
337 | if (ret == 0 || ret == 2) { | |
338 | OpenSSLTLSConnection* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(s, OpenSSLTLSConnection::s_tlsConnIndex)); | |
339 | if (conn) { | |
340 | if (ret == 0) { | |
341 | conn->setUnknownTicketKey(); | |
342 | } | |
343 | else if (ret == 2) { | |
344 | conn->setResumedFromInactiveTicketKey(); | |
345 | } | |
346 | } | |
347 | } | |
348 | } | |
349 | ||
350 | return ret; | |
a227f47d RG |
351 | } |
352 | ||
be3183ed RG |
353 | static int ocspStaplingCb(SSL* ssl, void* arg) |
354 | { | |
355 | if (ssl == nullptr || arg == nullptr) { | |
356 | return SSL_TLSEXT_ERR_NOACK; | |
357 | } | |
358 | const auto ocspMap = reinterpret_cast<std::map<int, std::string>*>(arg); | |
359 | return libssl_ocsp_stapling_callback(ssl, *ocspMap); | |
360 | } | |
361 | ||
a227f47d RG |
362 | std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override |
363 | { | |
364 | handleTicketsKeyRotation(now); | |
365 | ||
33a55a38 | 366 | return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_feContext)); |
a227f47d RG |
367 | } |
368 | ||
369 | void rotateTicketsKey(time_t now) override | |
370 | { | |
33a55a38 | 371 | d_feContext->d_ticketKeys.rotateTicketsKey(now); |
a227f47d RG |
372 | |
373 | if (d_ticketsKeyRotationDelay > 0) { | |
2d29e6b7 | 374 | d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay; |
a227f47d RG |
375 | } |
376 | } | |
377 | ||
8dfec397 | 378 | void loadTicketsKeys(const std::string& keyFile) override final |
a227f47d | 379 | { |
33a55a38 | 380 | d_feContext->d_ticketKeys.loadTicketsKeys(keyFile); |
a227f47d RG |
381 | |
382 | if (d_ticketsKeyRotationDelay > 0) { | |
383 | d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay; | |
384 | } | |
a227f47d RG |
385 | } |
386 | ||
387 | size_t getTicketsKeysCount() override | |
388 | { | |
33a55a38 | 389 | return d_feContext->d_ticketKeys.getKeysCount(); |
a227f47d RG |
390 | } |
391 | ||
392 | private: | |
33a55a38 | 393 | std::shared_ptr<OpenSSLFrontendContext> d_feContext; |
a227f47d RG |
394 | }; |
395 | ||
a227f47d RG |
396 | #endif /* HAVE_LIBSSL */ |
397 | ||
398 | #ifdef HAVE_GNUTLS | |
399 | #include <gnutls/gnutls.h> | |
400 | #include <gnutls/x509.h> | |
401 | ||
/* Locks the `size` bytes at `data` with libsodium's sodium_mlock() when libsodium is
   available (presumably to keep key material out of swap); a no-op otherwise. */
static void safe_memory_lock(void* data, size_t size)
{
#ifndef HAVE_LIBSODIUM
  (void) data;
  (void) size;
#else
  sodium_mlock(data, size);
#endif
}
408 | ||
/* Wipes the `size` bytes at `data`, using the first available primitive:
   sodium_munlock() (the libsodium counterpart of safe_memory_lock()), explicit_bzero(),
   explicit_memset(), gnutls_memset(), or, as a last resort, a memset() whose result is
   re-read through a volatile pointer - presumably so the compiler cannot discard it as
   a dead store. */
static void safe_memory_release(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
  if (size == 0) {
    return;
  }

  volatile unsigned int zeroIndex = 0;
  volatile unsigned char* bytes = reinterpret_cast<volatile unsigned char*>(data);

  for (;;) {
    memset(data, 0, size);
    if (bytes[zeroIndex] == 0) {
      break;
    }
  }
#endif
}
68aaaa06 | 432 | |
a227f47d RG |
433 | class GnuTLSTicketsKey |
434 | { | |
435 | public: | |
436 | GnuTLSTicketsKey() | |
437 | { | |
438 | if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) { | |
439 | throw std::runtime_error("Error generating tickets key for TLS context"); | |
440 | } | |
441 | ||
7e81628b | 442 | safe_memory_lock(d_key.data, d_key.size); |
a227f47d RG |
443 | } |
444 | ||
445 | GnuTLSTicketsKey(const std::string& keyFile) | |
446 | { | |
447 | /* to be sure we are loading the correct amount of data, which | |
448 | may change between versions, let's generate a correct key first */ | |
449 | if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) { | |
450 | throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context"); | |
451 | } | |
452 | ||
7e81628b | 453 | safe_memory_lock(d_key.data, d_key.size); |
a227f47d RG |
454 | |
455 | try { | |
456 | ifstream file(keyFile); | |
457 | file.read(reinterpret_cast<char*>(d_key.data), d_key.size); | |
458 | ||
459 | if (file.fail()) { | |
460 | file.close(); | |
461 | throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile); | |
462 | } | |
463 | ||
464 | file.close(); | |
465 | } | |
466 | catch (const std::exception& e) { | |
7e81628b | 467 | safe_memory_release(d_key.data, d_key.size); |
a227f47d | 468 | gnutls_free(d_key.data); |
1f65c18c | 469 | d_key.data = nullptr; |
a227f47d RG |
470 | throw; |
471 | } | |
472 | } | |
473 | ||
474 | ~GnuTLSTicketsKey() | |
475 | { | |
476 | if (d_key.data != nullptr && d_key.size > 0) { | |
7e81628b | 477 | safe_memory_release(d_key.data, d_key.size); |
a227f47d RG |
478 | } |
479 | gnutls_free(d_key.data); | |
1f65c18c | 480 | d_key.data = nullptr; |
a227f47d RG |
481 | } |
482 | const gnutls_datum_t& getKey() const | |
483 | { | |
484 | return d_key; | |
485 | } | |
486 | ||
487 | private: | |
488 | gnutls_datum_t d_key{nullptr, 0}; | |
489 | }; | |
490 | ||
491 | class GnuTLSConnection: public TLSConnection | |
492 | { | |
493 | public: | |
494 | ||
8dd7033b | 495 | GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey) |
a227f47d | 496 | { |
d0ae6360 | 497 | unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK; |
a227f47d | 498 | #ifdef GNUTLS_NO_SIGNAL |
c8d7b468 | 499 | sslOptions |= GNUTLS_NO_SIGNAL; |
a227f47d | 500 | #endif |
c8d7b468 RG |
501 | |
502 | d_socket = socket; | |
503 | ||
8dd7033b RG |
504 | gnutls_session_t conn; |
505 | if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) { | |
a227f47d RG |
506 | throw std::runtime_error("Error creating TLS connection"); |
507 | } | |
508 | ||
8dd7033b RG |
509 | d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit); |
510 | conn = nullptr; | |
511 | ||
512 | if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) { | |
a227f47d RG |
513 | throw std::runtime_error("Error setting certificate and key to TLS connection"); |
514 | } | |
515 | ||
8dd7033b | 516 | if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) { |
a227f47d RG |
517 | throw std::runtime_error("Error setting ciphers to TLS connection"); |
518 | } | |
519 | ||
ba20dc97 | 520 | if (enableTickets && d_ticketsKey) { |
a227f47d | 521 | const gnutls_datum_t& key = d_ticketsKey->getKey(); |
8dd7033b | 522 | if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) { |
a227f47d RG |
523 | throw std::runtime_error("Error setting the tickets key to TLS connection"); |
524 | } | |
525 | } | |
526 | ||
8dd7033b | 527 | gnutls_transport_set_int(d_conn.get(), d_socket); |
a227f47d RG |
528 | |
529 | /* timeouts are in milliseconds */ | |
8dd7033b RG |
530 | gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000); |
531 | gnutls_record_set_timeout(d_conn.get(), timeout * 1000); | |
d0ae6360 | 532 | } |
a227f47d | 533 | |
3163698b | 534 | void doHandshake() override |
d0ae6360 | 535 | { |
a227f47d RG |
536 | int ret = 0; |
537 | do { | |
8dd7033b | 538 | ret = gnutls_handshake(d_conn.get()); |
d0ae6360 RG |
539 | if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { |
540 | throw std::runtime_error("Error accepting a new connection"); | |
541 | } | |
542 | } | |
543 | while (ret < 0 && ret == GNUTLS_E_INTERRUPTED); | |
544 | } | |
545 | ||
3163698b | 546 | IOState tryHandshake() override |
d0ae6360 RG |
547 | { |
548 | int ret = 0; | |
549 | ||
550 | do { | |
551 | ret = gnutls_handshake(d_conn.get()); | |
552 | if (ret == GNUTLS_E_SUCCESS) { | |
553 | return IOState::Done; | |
554 | } | |
555 | else if (ret == GNUTLS_E_AGAIN) { | |
556 | return IOState::NeedRead; | |
557 | } | |
558 | else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { | |
559 | throw std::runtime_error("Error accepting a new connection"); | |
560 | } | |
561 | } while (ret == GNUTLS_E_INTERRUPTED); | |
562 | ||
563 | throw std::runtime_error("Error accepting a new connection"); | |
564 | } | |
565 | ||
566 | IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override | |
567 | { | |
568 | do { | |
569 | ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos); | |
570 | if (res == 0) { | |
571 | throw std::runtime_error("Error writing to TLS connection"); | |
572 | } | |
573 | else if (res > 0) { | |
574 | pos += static_cast<size_t>(res); | |
575 | } | |
576 | else if (res < 0) { | |
577 | if (gnutls_error_is_fatal(res)) { | |
11102d05 | 578 | throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res))); |
d0ae6360 RG |
579 | } |
580 | else if (res == GNUTLS_E_AGAIN) { | |
581 | return IOState::NeedWrite; | |
582 | } | |
583 | warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); | |
584 | } | |
585 | } | |
586 | while (pos < toWrite); | |
587 | return IOState::Done; | |
588 | } | |
589 | ||
590 | IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override | |
591 | { | |
592 | do { | |
593 | ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos); | |
594 | if (res == 0) { | |
595 | throw std::runtime_error("Error reading from TLS connection"); | |
596 | } | |
597 | else if (res > 0) { | |
598 | pos += static_cast<size_t>(res); | |
599 | } | |
600 | else if (res < 0) { | |
601 | if (gnutls_error_is_fatal(res)) { | |
11102d05 | 602 | throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res))); |
d0ae6360 RG |
603 | } |
604 | else if (res == GNUTLS_E_AGAIN) { | |
605 | return IOState::NeedRead; | |
606 | } | |
607 | warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); | |
608 | } | |
a227f47d | 609 | } |
d0ae6360 RG |
610 | while (pos < toRead); |
611 | return IOState::Done; | |
a227f47d RG |
612 | } |
613 | ||
a227f47d RG |
614 | size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override |
615 | { | |
616 | size_t got = 0; | |
617 | time_t start = 0; | |
618 | unsigned int remainingTime = totalTimeout; | |
619 | if (totalTimeout) { | |
620 | start = time(nullptr); | |
621 | } | |
622 | ||
623 | do { | |
8dd7033b | 624 | ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got); |
a227f47d RG |
625 | if (res == 0) { |
626 | throw std::runtime_error("Error reading from TLS connection"); | |
627 | } | |
628 | else if (res > 0) { | |
d0ae6360 | 629 | got += static_cast<size_t>(res); |
a227f47d RG |
630 | } |
631 | else if (res < 0) { | |
632 | if (gnutls_error_is_fatal(res)) { | |
11102d05 | 633 | throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res))); |
d7aff6e6 RG |
634 | } |
635 | else if (res == GNUTLS_E_AGAIN) { | |
636 | int result = waitForData(d_socket, readTimeout); | |
637 | if (result <= 0) { | |
11102d05 | 638 | throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result)); |
d7aff6e6 RG |
639 | } |
640 | } | |
641 | else { | |
642 | vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res)); | |
a227f47d | 643 | } |
a227f47d RG |
644 | } |
645 | ||
646 | if (totalTimeout) { | |
647 | time_t now = time(nullptr); | |
648 | unsigned int elapsed = now - start; | |
649 | if (now < start || elapsed >= remainingTime) { | |
650 | throw runtime_error("Timeout while reading data"); | |
651 | } | |
652 | start = now; | |
653 | remainingTime -= elapsed; | |
654 | } | |
655 | } | |
656 | while (got < bufferSize); | |
657 | ||
658 | return got; | |
659 | } | |
660 | ||
661 | size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override | |
662 | { | |
663 | size_t got = 0; | |
664 | ||
665 | do { | |
8dd7033b | 666 | ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got); |
a227f47d RG |
667 | if (res == 0) { |
668 | throw std::runtime_error("Error writing to TLS connection"); | |
669 | } | |
670 | else if (res > 0) { | |
d0ae6360 | 671 | got += static_cast<size_t>(res); |
a227f47d RG |
672 | } |
673 | else if (res < 0) { | |
674 | if (gnutls_error_is_fatal(res)) { | |
11102d05 | 675 | throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res))); |
d7aff6e6 RG |
676 | } |
677 | else if (res == GNUTLS_E_AGAIN) { | |
678 | int result = waitForRWData(d_socket, false, writeTimeout, 0); | |
679 | if (result <= 0) { | |
680 | throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result)); | |
681 | } | |
682 | } | |
683 | else { | |
684 | vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); | |
a227f47d | 685 | } |
a227f47d RG |
686 | } |
687 | } | |
688 | while (got < bufferSize); | |
689 | ||
690 | return got; | |
691 | } | |
692 | ||
bb3954f0 | 693 | std::string getServerNameIndication() const override |
046bac5c RG |
694 | { |
695 | if (d_conn) { | |
696 | unsigned int type; | |
697 | size_t name_len = 256; | |
698 | std::string sni; | |
699 | sni.resize(name_len); | |
700 | ||
701 | int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0); | |
702 | if (res == GNUTLS_E_SUCCESS) { | |
703 | sni.resize(name_len); | |
704 | return sni; | |
705 | } | |
706 | } | |
707 | return std::string(); | |
708 | } | |
709 | ||
bb3954f0 RG |
710 | LibsslTLSVersion getTLSVersion() const override |
711 | { | |
712 | auto proto = gnutls_protocol_get_version(d_conn.get()); | |
713 | switch (proto) { | |
714 | case GNUTLS_TLS1_0: | |
715 | return LibsslTLSVersion::TLS10; | |
716 | case GNUTLS_TLS1_1: | |
717 | return LibsslTLSVersion::TLS11; | |
718 | case GNUTLS_TLS1_2: | |
719 | return LibsslTLSVersion::TLS12; | |
720 | #if GNUTLS_VERSION_NUMBER >= 0x030603 | |
721 | case GNUTLS_TLS1_3: | |
722 | return LibsslTLSVersion::TLS13; | |
723 | #endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */ | |
724 | default: | |
725 | return LibsslTLSVersion::Unknown; | |
726 | } | |
727 | } | |
728 | ||
846b63bb RG |
729 | bool hasSessionBeenResumed() const override |
730 | { | |
731 | if (d_conn) { | |
732 | return gnutls_session_is_resumed(d_conn.get()) != 0; | |
733 | } | |
734 | return false; | |
735 | } | |
736 | ||
a227f47d RG |
737 | void close() override |
738 | { | |
739 | if (d_conn) { | |
8dd7033b | 740 | gnutls_bye(d_conn.get(), GNUTLS_SHUT_WR); |
a227f47d RG |
741 | } |
742 | } | |
743 | ||
744 | private: | |
8dd7033b | 745 | std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn; |
a227f47d RG |
746 | std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey; |
747 | }; | |
748 | ||
749 | class GnuTLSIOCtx: public TLSCtx | |
750 | { | |
751 | public: | |
f34fdcc5 | 752 | GnuTLSIOCtx(TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_tlsConfig.d_enableTickets) |
a227f47d | 753 | { |
0ca5d025 | 754 | int rc = 0; |
b54e94dc | 755 | d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay; |
a227f47d | 756 | |
8dd7033b RG |
757 | gnutls_certificate_credentials_t creds; |
758 | rc = gnutls_certificate_allocate_credentials(&creds); | |
0ca5d025 RG |
759 | if (rc != GNUTLS_E_SUCCESS) { |
760 | throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); | |
a227f47d RG |
761 | } |
762 | ||
8dd7033b RG |
763 | d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials); |
764 | creds = nullptr; | |
765 | ||
b54e94dc | 766 | for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) { |
8dd7033b | 767 | rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM); |
fa974ada | 768 | if (rc != GNUTLS_E_SUCCESS) { |
fa974ada RG |
769 | throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); |
770 | } | |
a227f47d RG |
771 | } |
772 | ||
be3183ed | 773 | size_t count = 0; |
b54e94dc | 774 | for (const auto& file : fe.d_tlsConfig.d_ocspFiles) { |
be3183ed RG |
775 | rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count); |
776 | if (rc != GNUTLS_E_SUCCESS) { | |
b54e94dc | 777 | throw std::runtime_error("Error loading OCSP response from file '" + file + "' for certificate ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).first + "') and key ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); |
be3183ed RG |
778 | } |
779 | ++count; | |
780 | } | |
781 | ||
a227f47d | 782 | #if GNUTLS_VERSION_NUMBER >= 0x030600 |
8dd7033b | 783 | rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH); |
0ca5d025 | 784 | if (rc != GNUTLS_E_SUCCESS) { |
0ca5d025 | 785 | throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); |
a227f47d RG |
786 | } |
787 | #endif | |
788 | ||
b54e94dc | 789 | rc = gnutls_priority_init(&d_priorityCache, fe.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : fe.d_tlsConfig.d_ciphers.c_str(), nullptr); |
0ca5d025 | 790 | if (rc != GNUTLS_E_SUCCESS) { |
b54e94dc | 791 | throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort()); |
a227f47d RG |
792 | } |
793 | ||
794 | try { | |
b54e94dc | 795 | if (fe.d_tlsConfig.d_ticketKeyFile.empty()) { |
a227f47d RG |
796 | handleTicketsKeyRotation(time(nullptr)); |
797 | } | |
798 | else { | |
8dfec397 | 799 | GnuTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile); |
a227f47d RG |
800 | } |
801 | } | |
802 | catch(const std::runtime_error& e) { | |
a227f47d RG |
803 | throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what()); |
804 | } | |
805 | } | |
806 | ||
807 | virtual ~GnuTLSIOCtx() override | |
808 | { | |
8dd7033b RG |
809 | d_creds.reset(); |
810 | ||
a227f47d RG |
811 | if (d_priorityCache) { |
812 | gnutls_priority_deinit(d_priorityCache); | |
813 | } | |
814 | } | |
815 | ||
816 | std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override | |
817 | { | |
818 | handleTicketsKeyRotation(now); | |
819 | ||
1f65c18c RG |
820 | std::shared_ptr<GnuTLSTicketsKey> ticketsKey; |
821 | { | |
822 | ReadLock rl(&d_lock); | |
823 | ticketsKey = d_ticketsKey; | |
824 | } | |
825 | ||
826 | return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets)); | |
a227f47d RG |
827 | } |
828 | ||
829 | void rotateTicketsKey(time_t now) override | |
830 | { | |
ba20dc97 | 831 | if (!d_enableTickets) { |
c8d7b468 RG |
832 | return; |
833 | } | |
834 | ||
a227f47d | 835 | auto newKey = std::make_shared<GnuTLSTicketsKey>(); |
1f65c18c RG |
836 | |
837 | { | |
838 | WriteLock wl(&d_lock); | |
839 | d_ticketsKey = newKey; | |
840 | } | |
841 | ||
a227f47d | 842 | if (d_ticketsKeyRotationDelay > 0) { |
2d29e6b7 | 843 | d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay; |
a227f47d RG |
844 | } |
845 | } | |
846 | ||
8dfec397 | 847 | void loadTicketsKeys(const std::string& file) override final |
a227f47d | 848 | { |
ba20dc97 | 849 | if (!d_enableTickets) { |
c8d7b468 RG |
850 | return; |
851 | } | |
852 | ||
a227f47d | 853 | auto newKey = std::make_shared<GnuTLSTicketsKey>(file); |
1f65c18c RG |
854 | { |
855 | WriteLock wl(&d_lock); | |
856 | d_ticketsKey = newKey; | |
857 | } | |
858 | ||
a227f47d RG |
859 | if (d_ticketsKeyRotationDelay > 0) { |
860 | d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay; | |
861 | } | |
862 | } | |
863 | ||
864 | size_t getTicketsKeysCount() override | |
865 | { | |
1f65c18c | 866 | ReadLock rl(&d_lock); |
a227f47d RG |
867 | return d_ticketsKey != nullptr ? 1 : 0; |
868 | } | |
869 | ||
870 | private: | |
8dd7033b | 871 | std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds; |
a227f47d RG |
872 | gnutls_priority_t d_priorityCache{nullptr}; |
873 | std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr}; | |
f0941861 | 874 | ReadWriteLock d_lock; |
ba20dc97 | 875 | bool d_enableTickets{true}; |
a227f47d RG |
876 | }; |
877 | ||
878 | #endif /* HAVE_GNUTLS */ | |
879 | ||
880 | #endif /* HAVE_DNS_OVER_TLS */ | |
881 | ||
882 | bool TLSFrontend::setupTLS() | |
883 | { | |
884 | #ifdef HAVE_DNS_OVER_TLS | |
885 | /* get the "best" available provider */ | |
886 | if (!d_provider.empty()) { | |
887 | #ifdef HAVE_GNUTLS | |
888 | if (d_provider == "gnutls") { | |
889 | d_ctx = std::make_shared<GnuTLSIOCtx>(*this); | |
890 | return true; | |
891 | } | |
892 | #endif /* HAVE_GNUTLS */ | |
893 | #ifdef HAVE_LIBSSL | |
894 | if (d_provider == "openssl") { | |
895 | d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this); | |
896 | return true; | |
897 | } | |
898 | #endif /* HAVE_LIBSSL */ | |
899 | } | |
a227f47d RG |
900 | #ifdef HAVE_LIBSSL |
901 | d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this); | |
d94702d6 RG |
902 | #else /* HAVE_LIBSSL */ |
903 | #ifdef HAVE_GNUTLS | |
904 | d_ctx = std::make_shared<GnuTLSIOCtx>(*this); | |
a227f47d | 905 | #endif /* HAVE_GNUTLS */ |
d94702d6 | 906 | #endif /* HAVE_LIBSSL */ |
a227f47d RG |
907 | |
908 | #endif /* HAVE_DNS_OVER_TLS */ | |
909 | return true; | |
910 | } |