1 /* SPDX-License-Identifier: LGPL-2.1+ */
3 #if !ENABLE_DNS_OVER_TLS || !DNS_OVER_TLS_USE_OPENSSL
4 #error This source file requires DNS-over-TLS to be enabled and OpenSSL to be available.
7 #include <openssl/bio.h>
8 #include <openssl/err.h>
11 #include "resolved-dns-stream.h"
12 #include "resolved-dnstls.h"
/* Generate _cleanup_() helpers for OpenSSL objects. SSL_freep() and
 * BIO_freep() are the names used with _cleanup_() in
 * dnstls_stream_connect_tls(). The macro is defined elsewhere; presumably it
 * emits a function that calls SSL_free()/BIO_free() on *p when the annotated
 * variable goes out of scope. */
14 DEFINE_TRIVIAL_CLEANUP_FUNC(SSL
*, SSL_free
);
15 DEFINE_TRIVIAL_CLEANUP_FUNC(BIO
*, BIO_free
);
/* Push ciphertext that OpenSSL has accumulated in the memory write BIO out to
 * the socket.
 *
 * stream->dnstls_data.write_buffer is the buffer backing that BIO. It is
 * obtained with BIO_get_mem_ptr() in dnstls_stream_connect_tls().
 * buffer_offset records how much of the buffer has already been written.
 * Data is sent with dns_stream_writev(..., DNS_STREAM_WRITE_TLS_DATA).
 *
 * While data remains unsent, EPOLLOUT is added to stream->dnstls_events so
 * the caller waits for writability. Once everything is written, the BIO is
 * reset and the offset returns to 0.
 *
 * NOTE(review): several original lines (local declarations, error checks,
 * return statements) are missing from this view, so the exact return
 * contract is not visible here. Callers in this file treat -EAGAIN from this
 * function as non-fatal. */
17 static int dnstls_flush_write_buffer(DnsStream
*stream
) {
21 assert(stream
->encrypted
);
/* Only act if there is unsent data past the current offset. */
23 if (stream
->dnstls_data
.buffer_offset
< stream
->dnstls_data
.write_buffer
->length
) {
24 assert(stream
->dnstls_data
.write_buffer
->data
);
/* A single iovec covering the not-yet-written tail of the buffer. */
27 iov
[0] = IOVEC_MAKE(stream
->dnstls_data
.write_buffer
->data
+ stream
->dnstls_data
.buffer_offset
,
28 stream
->dnstls_data
.write_buffer
->length
- stream
->dnstls_data
.buffer_offset
);
29 ss
= dns_stream_writev(stream
, iov
, 1, DNS_STREAM_WRITE_TLS_DATA
);
/* NOTE(review): the condition guarding this statement (original lines
 * 30-31) is not visible. Presumably it is the -EAGAIN result of
 * dns_stream_writev(), where the socket must become writable first. */
32 stream
->dnstls_events
|= EPOLLOUT
;
/* Advance past the bytes actually written; short writes are possible. */
36 stream
->dnstls_data
.buffer_offset
+= ss
;
38 if (stream
->dnstls_data
.buffer_offset
< stream
->dnstls_data
.write_buffer
->length
) {
/* Partial write: wait for EPOLLOUT before sending the remainder. */
39 stream
->dnstls_events
|= EPOLLOUT
;
/* Presumably reached only once the buffer is fully drained. The intervening
 * original lines 40-41 are not visible. Empty the memory BIO and restart at
 * offset 0. */
42 BIO_reset(SSL_get_wbio(stream
->dnstls_data
.ssl
));
43 stream
->dnstls_data
.buffer_offset
= 0;
/* Create the client-side TLS session for `stream`, talking to `server`.
 *
 * Reads go directly to the socket through a socket BIO on stream->fd.
 * Writes go into a memory BIO, whose buffer is exposed as
 * stream->dnstls_data.write_buffer. dnstls_flush_write_buffer() later
 * pushes that buffer to the socket.
 *
 * Any session cached on the server (saved by dnstls_stream_shutdown()) is
 * offered for resumption, and the handshake is started immediately.
 * SSL_ERROR_WANT_READ/WANT_WRITE are not treated as failures, because
 * dnstls_stream_on_io() continues the handshake while
 * dnstls_data.handshake <= 0.
 *
 * Afterwards the stream is marked encrypted and takes ownership of the SSL
 * object. If the first flush fails with anything other than -EAGAIN, the SSL
 * object is freed again.
 *
 * NOTE(review): several original lines (declarations of error/r/errbuf,
 * allocation-failure checks, return statements) are missing from this view.
 * NOTE(review): unless reset on a line not visible here, stream->encrypted
 * stays true after the SSL object is freed on flush failure. */
51 int dnstls_stream_connect_tls(DnsStream
*stream
, DnsServer
*server
) {
/* Freed automatically on early exit, unless ownership is taken below. */
52 _cleanup_(BIO_freep
) BIO
*rb
= NULL
, *wb
= NULL
;
53 _cleanup_(SSL_freep
) SSL
*s
= NULL
;
/* Read side: socket BIO on the stream's fd. close_flag 0 (BIO_NOCLOSE)
 * means freeing the BIO does not close the fd. */
59 rb
= BIO_new_socket(stream
->fd
, 0);
/* Write side: memory BIO, so ciphertext can be sent via dns_stream_writev(). */
63 wb
= BIO_new(BIO_s_mem());
/* Expose the memory BIO's buffer to dnstls_flush_write_buffer(). */
67 BIO_get_mem_ptr(wb
, &stream
->dnstls_data
.write_buffer
);
68 stream
->dnstls_data
.buffer_offset
= 0;
70 s
= SSL_new(server
->dnstls_data
.ctx
);
/* Client role. Offer the server's cached session, if any, for resumption. */
74 SSL_set_connect_state(s
);
75 SSL_set_session(s
, server
->dnstls_data
.session
);
/* SSL_set_bio() takes ownership of both BIOs, so disarm their cleanup. */
76 SSL_set_bio(s
, TAKE_PTR(rb
), TAKE_PTR(wb
));
/* Start the handshake. A result <= 0 means it is unfinished or has failed. */
79 stream
->dnstls_data
.handshake
= SSL_do_handshake(s
);
80 if (stream
->dnstls_data
.handshake
<= 0) {
81 error
= SSL_get_error(s
, stream
->dnstls_data
.handshake
);
/* Anything other than "needs more I/O" is a hard handshake failure. */
82 if (!IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
/* NOTE(review): `error` is an SSL_get_error() result (SSL_ERROR_*), not
 * an error-queue code from ERR_get_error(). ERR_error_string_n() will
 * therefore not describe the real failure; consider ERR_get_error(). */
85 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
86 log_debug("Failed to invoke SSL_do_handshake: %s", errbuf
);
/* Hand the SSL object over to the stream. */
91 stream
->encrypted
= true;
92 stream
->dnstls_data
.ssl
= TAKE_PTR(s
);
/* Send whatever the handshake has written into the memory BIO so far. */
94 r
= dnstls_flush_write_buffer(stream
);
95 if (r
< 0 && r
!= -EAGAIN
) {
96 SSL_free(TAKE_PTR(stream
->dnstls_data
.ssl
));
/* Release the TLS state of an encrypted stream.
 *
 * SSL_free() also frees the BIOs that dnstls_stream_connect_tls() attached
 * with SSL_set_bio(). dnstls_data.ssl may already be NULL, for example after
 * dnstls_stream_connect_tls() dropped it on a failed flush.
 *
 * NOTE(review): SSL_free(NULL) is a no-op in OpenSSL, so the NULL check is
 * redundant. */
103 void dnstls_stream_free(DnsStream
*stream
) {
105 assert(stream
->encrypted
);
107 if (stream
->dnstls_data
.ssl
)
108 SSL_free(stream
->dnstls_data
.ssl
);
/* Drive TLS progress for an encrypted stream when its fd reports events.
 *
 * `revents` holds the poll events that fired. stream->dnstls_events records
 * which event OpenSSL last asked to wait for: EPOLLIN for
 * SSL_ERROR_WANT_READ, and EPOLLOUT for SSL_ERROR_WANT_WRITE or pending
 * ciphertext.
 *
 * Order of work:
 *  1. If EPOLLOUT fired and was requested, flush the memory write BIO.
 *  2. If a shutdown is in progress (dnstls_data.shutdown, set by
 *     dnstls_stream_shutdown()), continue SSL_shutdown(). When it ends, the
 *     flag is cleared and the stream reference taken when the flag was set
 *     is dropped. The function then returns DNSTLS_STREAM_CLOSED.
 *  3. Otherwise, if the handshake has not completed (handshake <= 0),
 *     continue SSL_do_handshake(). A hard failure is returned through
 *     log_debug_errno(SYNTHETIC_ERRNO(ECONNREFUSED), ...), presumably as
 *     -ECONNREFUSED.
 *
 * NOTE(review): several original lines (declarations, error checks, some
 * branch headers, return statements) are missing from this view. The
 * comments below describe only the statements present. */
111 int dnstls_stream_on_io(DnsStream
*stream
, uint32_t revents
) {
115 assert(stream
->encrypted
);
116 assert(stream
->dnstls_data
.ssl
);
118 /* Flush write buffer when requested by OpenSSL */
119 if ((revents
& EPOLLOUT
) && (stream
->dnstls_events
& EPOLLOUT
)) {
120 r
= dnstls_flush_write_buffer(stream
);
/* A shutdown started by dnstls_stream_shutdown() is still pending. */
125 if (stream
->dnstls_data
.shutdown
) {
127 r
= SSL_shutdown(stream
->dnstls_data
.ssl
);
/* NOTE(review): the branch header (original line 128) is not visible.
 * Presumably this is the SSL_shutdown() == 0 case: stop waiting on a
 * specific event and flush what was written. */
129 stream
->dnstls_events
= 0;
131 r
= dnstls_flush_write_buffer(stream
);
/* Presumably the SSL_shutdown() < 0 case: classify the error. */
137 error
= SSL_get_error(stream
->dnstls_data
.ssl
, r
);
138 if (IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
/* Needs more I/O: wait for the requested direction and flush. */
139 stream
->dnstls_events
= error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
141 r
= dnstls_flush_write_buffer(stream
);
146 } else if (error
== SSL_ERROR_SYSCALL
) {
148 log_debug_errno(errno
, "Failed to invoke SSL_shutdown, ignoring: %m");
/* Other errors are logged and ignored.
 * NOTE(review): ERR_error_string_n() expects an error-queue code
 * (ERR_get_error()), not the SSL_ERROR_* value held in `error`. */
152 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
153 log_debug("Failed to invoke SSL_shutdown, ignoring: %s", errbuf
);
/* The shutdown is over, whether it completed or failed in an ignored way.
 * Clear the state, flush, and drop the reference taken in
 * dnstls_stream_shutdown() when the shutdown flag was set. */
157 stream
->dnstls_events
= 0;
158 stream
->dnstls_data
.shutdown
= false;
160 r
= dnstls_flush_write_buffer(stream
);
164 dns_stream_unref(stream
);
165 return DNSTLS_STREAM_CLOSED
;
166 } else if (stream
->dnstls_data
.handshake
<= 0) {
/* The handshake is still in progress: continue it. */
168 stream
->dnstls_data
.handshake
= SSL_do_handshake(stream
->dnstls_data
.ssl
);
169 if (stream
->dnstls_data
.handshake
<= 0) {
170 error
= SSL_get_error(stream
->dnstls_data
.ssl
, stream
->dnstls_data
.handshake
);
171 if (IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
/* Needs more I/O: wait for the requested direction and flush. */
172 stream
->dnstls_events
= error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
173 r
= dnstls_flush_write_buffer(stream
);
/* Hard handshake failure.
 * NOTE(review): the same ERR_error_string_n()/SSL_ERROR_* mismatch as
 * above applies here. */
181 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
182 return log_debug_errno(SYNTHETIC_ERRNO(ECONNREFUSED
),
183 "Failed to invoke SSL_do_handshake: %s",
/* Presumably reached once SSL_do_handshake() has succeeded. Nothing
 * specific to wait for any more; flush the remaining handshake output. */
188 stream
->dnstls_events
= 0;
189 r
= dnstls_flush_write_buffer(stream
);
/* Start closing the TLS layer of an encrypted stream.
 *
 * `error` is the reason for closing, as a positive errno value (it is
 * compared against ETIMEDOUT).
 *
 * If the stream is associated with a server, the current TLS session is
 * first saved on the server so dnstls_stream_connect_tls() can offer it for
 * resumption. SSL_get1_session() returns a new reference, and any previously
 * cached session is freed.
 *
 * The visible SSL_shutdown() call is made only when error == ETIMEDOUT. When
 * the shutdown cannot finish immediately, dnstls_data.shutdown is set and a
 * stream reference is taken, at most once. This keeps the stream alive while
 * dnstls_stream_on_io() completes the shutdown, and that function drops the
 * reference.
 *
 * NOTE(review): several original lines (declarations, branch headers, return
 * statements) are missing from this view. */
197 int dnstls_stream_shutdown(DnsStream
*stream
, int error
) {
202 assert(stream
->encrypted
);
203 assert(stream
->dnstls_data
.ssl
);
/* Cache the session on the server for resumption by a later connection. */
205 if (stream
->server
) {
206 s
= SSL_get1_session(stream
->dnstls_data
.ssl
);
208 if (stream
->server
->dnstls_data
.session
)
209 SSL_SESSION_free(stream
->server
->dnstls_data
.session
);
211 stream
->server
->dnstls_data
.session
= s
;
215 if (error
== ETIMEDOUT
) {
217 r
= SSL_shutdown(stream
->dnstls_data
.ssl
);
/* NOTE(review): the branch header (original line 218) is not visible.
 * Presumably this is the SSL_shutdown() == 0 case (not finished yet): mark
 * the shutdown pending and pin the stream with a reference. */
219 if (!stream
->dnstls_data
.shutdown
) {
220 stream
->dnstls_data
.shutdown
= true;
221 dns_stream_ref(stream
);
224 stream
->dnstls_events
= 0;
226 r
= dnstls_flush_write_buffer(stream
);
/* Presumably the SSL_shutdown() < 0 case: classify the error. */
232 ssl_error
= SSL_get_error(stream
->dnstls_data
.ssl
, r
);
233 if (IN_SET(ssl_error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
/* Needs more I/O: wait for the requested direction, flush, and mark the
 * shutdown pending so dnstls_stream_on_io() continues it. */
234 stream
->dnstls_events
= ssl_error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
/* NOTE(review): the statement controlled by the following `if` (original
 * line 237) is not visible here; presumably it is an early return of r. */
235 r
= dnstls_flush_write_buffer(stream
);
236 if (r
< 0 && r
!= -EAGAIN
)
239 if (!stream
->dnstls_data
.shutdown
) {
240 stream
->dnstls_data
.shutdown
= true;
241 dns_stream_ref(stream
);
244 } else if (ssl_error
== SSL_ERROR_SYSCALL
) {
246 log_debug_errno(errno
, "Failed to invoke SSL_shutdown, ignoring: %m");
/* Other errors are logged and ignored.
 * NOTE(review): ERR_error_string_n() expects an error-queue code
 * (ERR_get_error()), not the SSL_ERROR_* value held in `ssl_error`. */
250 ERR_error_string_n(ssl_error
, errbuf
, sizeof(errbuf
));
251 log_debug("Failed to invoke SSL_shutdown, ignoring: %s", errbuf
);
/* Nothing left to wait for; flush any final TLS output. */
255 stream
->dnstls_events
= 0;
256 r
= dnstls_flush_write_buffer(stream
);
/* Encrypt and send up to `count` bytes of `buf` over the stream's TLS
 * session.
 *
 * SSL_write() places ciphertext in the memory write BIO, and
 * dnstls_flush_write_buffer() then pushes it to the socket.
 * stream->dnstls_events is set to the event OpenSSL needs next
 * (EPOLLIN/EPOLLOUT), or cleared.
 *
 * NOTE(review): several original lines (declarations, return statements, the
 * branch headers around the error and success paths) are missing from this
 * view, so the exact return values are not visible here.
 * NOTE(review): SSL_write() takes an int length, so a count above INT_MAX
 * would be truncated by the implicit conversion. Presumably callers only pass
 * DNS-message-sized buffers; confirm. */
264 ssize_t
dnstls_stream_write(DnsStream
*stream
, const char *buf
, size_t count
) {
269 assert(stream
->encrypted
);
270 assert(stream
->dnstls_data
.ssl
);
/* Keep the result as int `r` for SSL_get_error() and as ssize_t `ss`. */
274 ss
= r
= SSL_write(stream
->dnstls_data
.ssl
, buf
, count
);
/* NOTE(review): the guard on original line 275 is not visible. Presumably
 * this is only reached when SSL_write() returned <= 0. */
276 error
= SSL_get_error(stream
->dnstls_data
.ssl
, r
);
277 if (IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
/* Retry later, once the fd is ready in the requested direction. */
278 stream
->dnstls_events
= error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
280 } else if (error
== SSL_ERROR_ZERO_RETURN
) {
/* The peer closed the TLS connection. */
281 stream
->dnstls_events
= 0;
/* Other errors.
 * NOTE(review): ERR_error_string_n() expects an error-queue code
 * (ERR_get_error()), not the SSL_ERROR_* value held in `error`. */
286 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
287 log_debug("Failed to invoke SSL_write: %s", errbuf
);
288 stream
->dnstls_events
= 0;
/* Presumably the success path (SSL_write() > 0): nothing to wait for. */
292 stream
->dnstls_events
= 0;
/* Send the ciphertext produced by SSL_write(). */
294 r
= dnstls_flush_write_buffer(stream
);
/* Read and decrypt up to `count` bytes into `buf` from the stream's TLS
 * session.
 *
 * stream->dnstls_events is set to the event OpenSSL needs next
 * (EPOLLIN/EPOLLOUT), or cleared. SSL_read() can itself need to write (it
 * may report SSL_ERROR_WANT_WRITE), so the memory write BIO is flushed
 * afterwards.
 *
 * NOTE(review): several original lines (declarations, return statements, the
 * branch headers around the error and success paths) are missing from this
 * view, so the exact return values are not visible here.
 * NOTE(review): SSL_read() takes an int length, so a count above INT_MAX
 * would be truncated by the implicit conversion. Presumably callers only pass
 * DNS-message-sized buffers; confirm. */
301 ssize_t
dnstls_stream_read(DnsStream
*stream
, void *buf
, size_t count
) {
306 assert(stream
->encrypted
);
307 assert(stream
->dnstls_data
.ssl
);
/* Keep the result as int `r` for SSL_get_error() and as ssize_t `ss`. */
311 ss
= r
= SSL_read(stream
->dnstls_data
.ssl
, buf
, count
);
/* NOTE(review): the guard on original line 312 is not visible. Presumably
 * this is only reached when SSL_read() returned <= 0. */
313 error
= SSL_get_error(stream
->dnstls_data
.ssl
, r
);
314 if (IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
/* Retry later, once the fd is ready in the requested direction. */
315 stream
->dnstls_events
= error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
317 } else if (error
== SSL_ERROR_ZERO_RETURN
) {
/* The peer closed the TLS connection. */
318 stream
->dnstls_events
= 0;
/* Other errors.
 * NOTE(review): ERR_error_string_n() expects an error-queue code
 * (ERR_get_error()), not the SSL_ERROR_* value held in `error`. */
323 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
324 log_debug("Failed to invoke SSL_read: %s", errbuf
);
325 stream
->dnstls_events
= 0;
/* Presumably the success path (SSL_read() > 0): nothing to wait for. */
329 stream
->dnstls_events
= 0;
331 /* Flush the write buffer in case of renegotiation: SSL_read() may have queued TLS records for the peer. */
332 r
= dnstls_flush_write_buffer(stream
);
/* Create the per-server TLS client context.
 *
 * TLS 1.2 is the minimum accepted protocol version. TLS-level compression is
 * disabled with SSL_OP_NO_COMPRESSION, presumably to avoid compression side
 * channels such as CRIME. If SSL_CTX_new() fails, server->dnstls_data.ctx
 * stays NULL and no options are applied.
 *
 * NOTE(review): the return value of SSL_CTX_set_min_proto_version() is not
 * checked, and a NULL ctx is not reported here. Confirm that callers (e.g.
 * the SSL_new() call in dnstls_stream_connect_tls()) handle a NULL ctx. */
339 void dnstls_server_init(DnsServer
*server
) {
342 server
->dnstls_data
.ctx
= SSL_CTX_new(TLS_client_method());
343 if (server
->dnstls_data
.ctx
) {
344 SSL_CTX_set_min_proto_version(server
->dnstls_data
.ctx
, TLS1_2_VERSION
);
345 SSL_CTX_set_options(server
->dnstls_data
.ctx
, SSL_OP_NO_COMPRESSION
);
/* Release the per-server TLS context and any cached resumption session (the
 * session stored by dnstls_stream_shutdown()).
 *
 * NOTE(review): SSL_CTX_free(NULL) and SSL_SESSION_free(NULL) are no-ops in
 * OpenSSL, so the NULL checks are redundant. The pointers are also not reset
 * to NULL in the lines visible here. */
349 void dnstls_server_free(DnsServer
*server
) {
352 if (server
->dnstls_data
.ctx
)
353 SSL_CTX_free(server
->dnstls_data
.ctx
);
355 if (server
->dnstls_data
.session
)
356 SSL_SESSION_free(server
->dnstls_data
.session
);