resolved: add missing error code check when initializing DNS-over-TLS
[thirdparty/systemd.git] / src / resolve / resolved-dnstls-openssl.c
1 /* SPDX-License-Identifier: LGPL-2.1+ */
2
3 #if !ENABLE_DNS_OVER_TLS || !DNS_OVER_TLS_USE_OPENSSL
4 #error This source file requires DNS-over-TLS to be enabled and OpenSSL to be available.
5 #endif
6
7 #include <openssl/bio.h>
8 #include <openssl/err.h>
9
10 #include "io-util.h"
11 #include "resolved-dns-stream.h"
12 #include "resolved-dnstls.h"
13
14 DEFINE_TRIVIAL_CLEANUP_FUNC(SSL*, SSL_free);
15 DEFINE_TRIVIAL_CLEANUP_FUNC(BIO*, BIO_free);
16
17 static int dnstls_flush_write_buffer(DnsStream *stream) {
18         ssize_t ss;
19
20         assert(stream);
21         assert(stream->encrypted);
22
23         if (stream->dnstls_data.buffer_offset < stream->dnstls_data.write_buffer->length) {
24                 assert(stream->dnstls_data.write_buffer->data);
25
26                 struct iovec iov[1];
27                 iov[0] = IOVEC_MAKE(stream->dnstls_data.write_buffer->data + stream->dnstls_data.buffer_offset,
28                                     stream->dnstls_data.write_buffer->length - stream->dnstls_data.buffer_offset);
29                 ss = dns_stream_writev(stream, iov, 1, DNS_STREAM_WRITE_TLS_DATA);
30                 if (ss < 0) {
31                         if (ss == -EAGAIN)
32                                 stream->dnstls_events |= EPOLLOUT;
33
34                         return ss;
35                 } else {
36                         stream->dnstls_data.buffer_offset += ss;
37
38                         if (stream->dnstls_data.buffer_offset < stream->dnstls_data.write_buffer->length) {
39                                 stream->dnstls_events |= EPOLLOUT;
40                                 return -EAGAIN;
41                         } else {
42                                 BIO_reset(SSL_get_wbio(stream->dnstls_data.ssl));
43                                 stream->dnstls_data.buffer_offset = 0;
44                         }
45                 }
46         }
47
48         return 0;
49 }
50
51 int dnstls_stream_connect_tls(DnsStream *stream, DnsServer *server) {
52         _cleanup_(BIO_freep) BIO *rb = NULL, *wb = NULL;
53         _cleanup_(SSL_freep) SSL *s = NULL;
54         int error, r;
55
56         assert(stream);
57         assert(stream->manager);
58         assert(server);
59
60         rb = BIO_new_socket(stream->fd, 0);
61         if (!rb)
62                 return -ENOMEM;
63
64         wb = BIO_new(BIO_s_mem());
65         if (!wb)
66                 return -ENOMEM;
67
68         BIO_get_mem_ptr(wb, &stream->dnstls_data.write_buffer);
69         stream->dnstls_data.buffer_offset = 0;
70
71         s = SSL_new(stream->manager->dnstls_data.ctx);
72         if (!s)
73                 return -ENOMEM;
74
75         SSL_set_connect_state(s);
76         SSL_set_session(s, server->dnstls_data.session);
77         SSL_set_bio(s, TAKE_PTR(rb), TAKE_PTR(wb));
78
79         ERR_clear_error();
80         stream->dnstls_data.handshake = SSL_do_handshake(s);
81         if (stream->dnstls_data.handshake <= 0) {
82                 error = SSL_get_error(s, stream->dnstls_data.handshake);
83                 if (!IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
84                         char errbuf[256];
85
86                         ERR_error_string_n(error, errbuf, sizeof(errbuf));
87                         log_debug("Failed to invoke SSL_do_handshake: %s", errbuf);
88                         return -ECONNREFUSED;
89                 }
90         }
91
92         stream->encrypted = true;
93         stream->dnstls_data.ssl = TAKE_PTR(s);
94
95         r = dnstls_flush_write_buffer(stream);
96         if (r < 0 && r != -EAGAIN) {
97                 SSL_free(TAKE_PTR(stream->dnstls_data.ssl));
98                 return r;
99         }
100
101         return 0;
102 }
103
104 void dnstls_stream_free(DnsStream *stream) {
105         assert(stream);
106         assert(stream->encrypted);
107
108         if (stream->dnstls_data.ssl)
109                 SSL_free(stream->dnstls_data.ssl);
110 }
111
112 int dnstls_stream_on_io(DnsStream *stream, uint32_t revents) {
113         int error, r;
114
115         assert(stream);
116         assert(stream->encrypted);
117         assert(stream->dnstls_data.ssl);
118
119         /* Flush write buffer when requested by OpenSSL */
120         if ((revents & EPOLLOUT) && (stream->dnstls_events & EPOLLOUT)) {
121                 r = dnstls_flush_write_buffer(stream);
122                 if (r < 0)
123                         return r;
124         }
125
126         if (stream->dnstls_data.shutdown) {
127                 ERR_clear_error();
128                 r = SSL_shutdown(stream->dnstls_data.ssl);
129                 if (r == 0) {
130                         stream->dnstls_events = 0;
131
132                         r = dnstls_flush_write_buffer(stream);
133                         if (r < 0)
134                                 return r;
135
136                         return -EAGAIN;
137                 } else if (r < 0) {
138                         error = SSL_get_error(stream->dnstls_data.ssl, r);
139                         if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
140                                 stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT;
141
142                                 r = dnstls_flush_write_buffer(stream);
143                                 if (r < 0)
144                                         return r;
145
146                                 return -EAGAIN;
147                         } else if (error == SSL_ERROR_SYSCALL) {
148                                 if (errno > 0)
149                                         log_debug_errno(errno, "Failed to invoke SSL_shutdown, ignoring: %m");
150                         } else {
151                                 char errbuf[256];
152
153                                 ERR_error_string_n(error, errbuf, sizeof(errbuf));
154                                 log_debug("Failed to invoke SSL_shutdown, ignoring: %s", errbuf);
155                         }
156                 }
157
158                 stream->dnstls_events = 0;
159                 stream->dnstls_data.shutdown = false;
160
161                 r = dnstls_flush_write_buffer(stream);
162                 if (r < 0)
163                         return r;
164
165                 dns_stream_unref(stream);
166                 return DNSTLS_STREAM_CLOSED;
167         } else if (stream->dnstls_data.handshake <= 0) {
168                 ERR_clear_error();
169                 stream->dnstls_data.handshake = SSL_do_handshake(stream->dnstls_data.ssl);
170                 if (stream->dnstls_data.handshake <= 0) {
171                         error = SSL_get_error(stream->dnstls_data.ssl, stream->dnstls_data.handshake);
172                         if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
173                                 stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT;
174                                 r = dnstls_flush_write_buffer(stream);
175                                 if (r < 0)
176                                         return r;
177
178                                 return -EAGAIN;
179                         } else {
180                                 char errbuf[256];
181
182                                 ERR_error_string_n(error, errbuf, sizeof(errbuf));
183                                 return log_debug_errno(SYNTHETIC_ERRNO(ECONNREFUSED),
184                                                        "Failed to invoke SSL_do_handshake: %s",
185                                                        errbuf);
186                         }
187                 }
188
189                 stream->dnstls_events = 0;
190                 r = dnstls_flush_write_buffer(stream);
191                 if (r < 0)
192                         return r;
193         }
194
195         return 0;
196 }
197
198 int dnstls_stream_shutdown(DnsStream *stream, int error) {
199         int ssl_error, r;
200         SSL_SESSION *s;
201
202         assert(stream);
203         assert(stream->encrypted);
204         assert(stream->dnstls_data.ssl);
205
206         if (stream->server) {
207                 s = SSL_get1_session(stream->dnstls_data.ssl);
208                 if (s) {
209                         if (stream->server->dnstls_data.session)
210                                 SSL_SESSION_free(stream->server->dnstls_data.session);
211
212                         stream->server->dnstls_data.session = s;
213                 }
214         }
215
216         if (error == ETIMEDOUT) {
217                 ERR_clear_error();
218                 r = SSL_shutdown(stream->dnstls_data.ssl);
219                 if (r == 0) {
220                         if (!stream->dnstls_data.shutdown) {
221                                 stream->dnstls_data.shutdown = true;
222                                 dns_stream_ref(stream);
223                         }
224
225                         stream->dnstls_events = 0;
226
227                         r = dnstls_flush_write_buffer(stream);
228                         if (r < 0)
229                                 return r;
230
231                         return -EAGAIN;
232                 } else if (r < 0) {
233                         ssl_error = SSL_get_error(stream->dnstls_data.ssl, r);
234                         if (IN_SET(ssl_error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
235                                 stream->dnstls_events = ssl_error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT;
236                                 r = dnstls_flush_write_buffer(stream);
237                                 if (r < 0 && r != -EAGAIN)
238                                         return r;
239
240                                 if (!stream->dnstls_data.shutdown) {
241                                         stream->dnstls_data.shutdown = true;
242                                         dns_stream_ref(stream);
243                                 }
244                                 return -EAGAIN;
245                         } else if (ssl_error == SSL_ERROR_SYSCALL) {
246                                 if (errno > 0)
247                                         log_debug_errno(errno, "Failed to invoke SSL_shutdown, ignoring: %m");
248                         } else {
249                                 char errbuf[256];
250
251                                 ERR_error_string_n(ssl_error, errbuf, sizeof(errbuf));
252                                 log_debug("Failed to invoke SSL_shutdown, ignoring: %s", errbuf);
253                         }
254                 }
255
256                 stream->dnstls_events = 0;
257                 r = dnstls_flush_write_buffer(stream);
258                 if (r < 0)
259                         return r;
260         }
261
262         return 0;
263 }
264
265 ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
266         int error, r;
267         ssize_t ss;
268
269         assert(stream);
270         assert(stream->encrypted);
271         assert(stream->dnstls_data.ssl);
272         assert(buf);
273
274         ERR_clear_error();
275         ss = r = SSL_write(stream->dnstls_data.ssl, buf, count);
276         if (r <= 0) {
277                 error = SSL_get_error(stream->dnstls_data.ssl, r);
278                 if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
279                         stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT;
280                         ss = -EAGAIN;
281                 } else if (error == SSL_ERROR_ZERO_RETURN) {
282                         stream->dnstls_events = 0;
283                         ss = 0;
284                 } else {
285                         char errbuf[256];
286
287                         ERR_error_string_n(error, errbuf, sizeof(errbuf));
288                         log_debug("Failed to invoke SSL_write: %s", errbuf);
289                         stream->dnstls_events = 0;
290                         ss = -EPIPE;
291                 }
292         } else
293                 stream->dnstls_events = 0;
294
295         r = dnstls_flush_write_buffer(stream);
296         if (r < 0)
297                 return r;
298
299         return ss;
300 }
301
302 ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) {
303         int error, r;
304         ssize_t ss;
305
306         assert(stream);
307         assert(stream->encrypted);
308         assert(stream->dnstls_data.ssl);
309         assert(buf);
310
311         ERR_clear_error();
312         ss = r = SSL_read(stream->dnstls_data.ssl, buf, count);
313         if (r <= 0) {
314                 error = SSL_get_error(stream->dnstls_data.ssl, r);
315                 if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
316                         stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT;
317                         ss = -EAGAIN;
318                 } else if (error == SSL_ERROR_ZERO_RETURN) {
319                         stream->dnstls_events = 0;
320                         ss = 0;
321                 } else {
322                         char errbuf[256];
323
324                         ERR_error_string_n(error, errbuf, sizeof(errbuf));
325                         log_debug("Failed to invoke SSL_read: %s", errbuf);
326                         stream->dnstls_events = 0;
327                         ss = -EPIPE;
328                 }
329         } else
330                 stream->dnstls_events = 0;
331
332         /* flush write buffer in cache of renegotiation */
333         r = dnstls_flush_write_buffer(stream);
334         if (r < 0)
335                 return r;
336
337         return ss;
338 }
339
340 void dnstls_server_free(DnsServer *server) {
341         assert(server);
342
343         if (server->dnstls_data.session)
344                 SSL_SESSION_free(server->dnstls_data.session);
345 }
346
347 int dnstls_manager_init(Manager *manager) {
348         int r;
349         assert(manager);
350
351         ERR_load_crypto_strings();
352         SSL_load_error_strings();
353         manager->dnstls_data.ctx = SSL_CTX_new(TLS_client_method());
354
355         if (!manager->dnstls_data.ctx)
356                 return -ENOMEM;
357
358         SSL_CTX_set_min_proto_version(manager->dnstls_data.ctx, TLS1_2_VERSION);
359         SSL_CTX_set_options(manager->dnstls_data.ctx, SSL_OP_NO_COMPRESSION);
360
361         return 0;
362 }
363
364 void dnstls_manager_free(Manager *manager) {
365         assert(manager);
366
367         if (manager->dnstls_data.ctx)
368                 SSL_CTX_free(manager->dnstls_data.ctx);
369 }