1 /* SPDX-License-Identifier: LGPL-2.1+ */
3 #if !ENABLE_DNS_OVER_TLS || !DNS_OVER_TLS_USE_OPENSSL
4 #error This source file requires DNS-over-TLS to be enabled and OpenSSL to be available.
7 #include <openssl/bio.h>
8 #include <openssl/err.h>
12 #include "resolved-dns-stream.h"
13 #include "resolved-dnstls.h"
15 DEFINE_TRIVIAL_CLEANUP_FUNC(SSL
*, SSL_free
);
16 DEFINE_TRIVIAL_CLEANUP_FUNC(BIO
*, BIO_free
);
18 static int dnstls_flush_write_buffer(DnsStream
*stream
) {
22 assert(stream
->encrypted
);
24 if (stream
->dnstls_data
.write_buffer
->length
> 0) {
25 assert(stream
->dnstls_data
.write_buffer
->data
);
28 iov
[0] = IOVEC_MAKE(stream
->dnstls_data
.write_buffer
->data
,
29 stream
->dnstls_data
.write_buffer
->length
);
30 ss
= dns_stream_writev(stream
, iov
, 1, DNS_STREAM_WRITE_TLS_DATA
);
33 stream
->dnstls_events
|= EPOLLOUT
;
37 stream
->dnstls_data
.write_buffer
->length
-= ss
;
39 if (stream
->dnstls_data
.write_buffer
->length
> 0) {
40 memmove(stream
->dnstls_data
.write_buffer
->data
,
41 stream
->dnstls_data
.write_buffer
->data
+ ss
,
42 stream
->dnstls_data
.write_buffer
->length
);
43 stream
->dnstls_events
|= EPOLLOUT
;
52 int dnstls_stream_connect_tls(DnsStream
*stream
, DnsServer
*server
) {
53 _cleanup_(BIO_freep
) BIO
*rb
= NULL
, *wb
= NULL
;
54 _cleanup_(SSL_freep
) SSL
*s
= NULL
;
60 rb
= BIO_new_socket(stream
->fd
, 0);
64 wb
= BIO_new(BIO_s_mem());
68 BIO_get_mem_ptr(wb
, &stream
->dnstls_data
.write_buffer
);
70 s
= SSL_new(server
->dnstls_data
.ctx
);
74 SSL_set_connect_state(s
);
75 SSL_set_session(s
, server
->dnstls_data
.session
);
76 SSL_set_bio(s
, TAKE_PTR(rb
), TAKE_PTR(wb
));
79 stream
->dnstls_data
.handshake
= SSL_do_handshake(s
);
80 if (stream
->dnstls_data
.handshake
<= 0) {
81 error
= SSL_get_error(s
, stream
->dnstls_data
.handshake
);
82 if (!IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
85 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
86 log_debug("Failed to invoke SSL_do_handshake: %s", errbuf
);
91 stream
->encrypted
= true;
93 r
= dnstls_flush_write_buffer(stream
);
94 if (r
< 0 && r
!= -EAGAIN
)
97 stream
->dnstls_data
.ssl
= TAKE_PTR(s
);
102 void dnstls_stream_free(DnsStream
*stream
) {
104 assert(stream
->encrypted
);
106 if (stream
->dnstls_data
.ssl
)
107 SSL_free(stream
->dnstls_data
.ssl
);
110 int dnstls_stream_on_io(DnsStream
*stream
, uint32_t revents
) {
114 assert(stream
->encrypted
);
115 assert(stream
->dnstls_data
.ssl
);
117 /* Flush write buffer when requested by OpenSSL */
118 if ((revents
& EPOLLOUT
) && (stream
->dnstls_events
& EPOLLOUT
)) {
119 r
= dnstls_flush_write_buffer(stream
);
124 if (stream
->dnstls_data
.shutdown
) {
126 r
= SSL_shutdown(stream
->dnstls_data
.ssl
);
128 stream
->dnstls_events
= 0;
130 r
= dnstls_flush_write_buffer(stream
);
136 error
= SSL_get_error(stream
->dnstls_data
.ssl
, r
);
137 if (IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
138 stream
->dnstls_events
= error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
140 r
= dnstls_flush_write_buffer(stream
);
145 } else if (error
== SSL_ERROR_SYSCALL
) {
147 log_debug_errno(errno
, "Failed to invoke SSL_shutdown, ignoring: %m");
151 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
152 log_debug("Failed to invoke SSL_shutdown, ignoring: %s", errbuf
);
156 stream
->dnstls_events
= 0;
157 stream
->dnstls_data
.shutdown
= false;
159 r
= dnstls_flush_write_buffer(stream
);
163 dns_stream_unref(stream
);
164 return DNSTLS_STREAM_CLOSED
;
165 } else if (stream
->dnstls_data
.handshake
<= 0) {
167 stream
->dnstls_data
.handshake
= SSL_do_handshake(stream
->dnstls_data
.ssl
);
168 if (stream
->dnstls_data
.handshake
<= 0) {
169 error
= SSL_get_error(stream
->dnstls_data
.ssl
, stream
->dnstls_data
.handshake
);
170 if (IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
171 stream
->dnstls_events
= error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
172 r
= dnstls_flush_write_buffer(stream
);
180 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
181 return log_debug_errno(SYNTHETIC_ERRNO(ECONNREFUSED
),
182 "Failed to invoke SSL_do_handshake: %s",
187 stream
->dnstls_events
= 0;
188 r
= dnstls_flush_write_buffer(stream
);
196 int dnstls_stream_shutdown(DnsStream
*stream
, int error
) {
201 assert(stream
->encrypted
);
202 assert(stream
->dnstls_data
.ssl
);
204 if (stream
->server
) {
205 s
= SSL_get1_session(stream
->dnstls_data
.ssl
);
207 if (stream
->server
->dnstls_data
.session
)
208 SSL_SESSION_free(stream
->server
->dnstls_data
.session
);
210 stream
->server
->dnstls_data
.session
= s
;
214 if (error
== ETIMEDOUT
) {
216 r
= SSL_shutdown(stream
->dnstls_data
.ssl
);
218 if (!stream
->dnstls_data
.shutdown
) {
219 stream
->dnstls_data
.shutdown
= true;
220 dns_stream_ref(stream
);
223 stream
->dnstls_events
= 0;
225 r
= dnstls_flush_write_buffer(stream
);
231 ssl_error
= SSL_get_error(stream
->dnstls_data
.ssl
, r
);
232 if (IN_SET(ssl_error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
233 stream
->dnstls_events
= ssl_error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
234 r
= dnstls_flush_write_buffer(stream
);
235 if (r
< 0 && r
!= -EAGAIN
)
238 if (!stream
->dnstls_data
.shutdown
) {
239 stream
->dnstls_data
.shutdown
= true;
240 dns_stream_ref(stream
);
243 } else if (ssl_error
== SSL_ERROR_SYSCALL
) {
245 log_debug_errno(errno
, "Failed to invoke SSL_shutdown, ignoring: %m");
249 ERR_error_string_n(ssl_error
, errbuf
, sizeof(errbuf
));
250 log_debug("Failed to invoke SSL_shutdown, ignoring: %s", errbuf
);
254 stream
->dnstls_events
= 0;
255 r
= dnstls_flush_write_buffer(stream
);
263 ssize_t
dnstls_stream_write(DnsStream
*stream
, const char *buf
, size_t count
) {
268 assert(stream
->encrypted
);
269 assert(stream
->dnstls_data
.ssl
);
273 ss
= r
= SSL_write(stream
->dnstls_data
.ssl
, buf
, count
);
275 error
= SSL_get_error(stream
->dnstls_data
.ssl
, r
);
276 if (IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
277 stream
->dnstls_events
= error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
279 } else if (error
== SSL_ERROR_ZERO_RETURN
) {
280 stream
->dnstls_events
= 0;
285 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
286 log_debug("Failed to invoke SSL_write: %s", errbuf
);
287 stream
->dnstls_events
= 0;
291 stream
->dnstls_events
= 0;
293 r
= dnstls_flush_write_buffer(stream
);
300 ssize_t
dnstls_stream_read(DnsStream
*stream
, void *buf
, size_t count
) {
305 assert(stream
->encrypted
);
306 assert(stream
->dnstls_data
.ssl
);
310 ss
= r
= SSL_read(stream
->dnstls_data
.ssl
, buf
, count
);
312 error
= SSL_get_error(stream
->dnstls_data
.ssl
, r
);
313 if (IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
314 stream
->dnstls_events
= error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
316 } else if (error
== SSL_ERROR_ZERO_RETURN
) {
317 stream
->dnstls_events
= 0;
322 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
323 log_debug("Failed to invoke SSL_read: %s", errbuf
);
324 stream
->dnstls_events
= 0;
328 stream
->dnstls_events
= 0;
330 /* flush write buffer in cache of renegotiation */
331 r
= dnstls_flush_write_buffer(stream
);
338 void dnstls_server_init(DnsServer
*server
) {
341 server
->dnstls_data
.ctx
= SSL_CTX_new(TLS_client_method());
342 if (server
->dnstls_data
.ctx
) {
343 SSL_CTX_set_min_proto_version(server
->dnstls_data
.ctx
, TLS1_2_VERSION
);
344 SSL_CTX_set_options(server
->dnstls_data
.ctx
, SSL_OP_NO_COMPRESSION
);
348 void dnstls_server_free(DnsServer
*server
) {
351 if (server
->dnstls_data
.ctx
)
352 SSL_CTX_free(server
->dnstls_data
.ctx
);
354 if (server
->dnstls_data
.session
)
355 SSL_SESSION_free(server
->dnstls_data
.session
);