1 /* SPDX-License-Identifier: LGPL-2.1+ */
3 #if !ENABLE_DNS_OVER_TLS || !DNS_OVER_TLS_USE_OPENSSL
4 #error This source file requires DNS-over-TLS to be enabled and OpenSSL to be available.
7 #include <openssl/bio.h>
8 #include <openssl/err.h>
10 #include "resolved-dns-stream.h"
11 #include "resolved-dnstls.h"
13 DEFINE_TRIVIAL_CLEANUP_FUNC(SSL
*, SSL_free
);
14 DEFINE_TRIVIAL_CLEANUP_FUNC(BIO
*, BIO_free
);
16 static int dnstls_flush_write_buffer(DnsStream
*stream
) {
20 assert(stream
->encrypted
);
22 if (stream
->dnstls_data
.write_buffer
->length
> 0) {
23 assert(stream
->dnstls_data
.write_buffer
->data
);
26 iov
[0].iov_base
= stream
->dnstls_data
.write_buffer
->data
;
27 iov
[0].iov_len
= stream
->dnstls_data
.write_buffer
->length
;
28 ss
= dns_stream_writev(stream
, iov
, 1, DNS_STREAM_WRITE_TLS_DATA
);
31 stream
->dnstls_events
|= EPOLLOUT
;
35 stream
->dnstls_data
.write_buffer
->length
-= ss
;
36 stream
->dnstls_data
.write_buffer
->data
+= ss
;
38 if (stream
->dnstls_data
.write_buffer
->length
> 0) {
39 stream
->dnstls_events
|= EPOLLOUT
;
48 int dnstls_stream_connect_tls(DnsStream
*stream
, DnsServer
*server
) {
49 _cleanup_(BIO_freep
) BIO
*rb
= NULL
, *wb
= NULL
;
50 _cleanup_(SSL_freep
) SSL
*s
= NULL
;
56 rb
= BIO_new_socket(stream
->fd
, 0);
60 wb
= BIO_new(BIO_s_mem());
64 BIO_get_mem_ptr(wb
, &stream
->dnstls_data
.write_buffer
);
66 s
= SSL_new(server
->dnstls_data
.ctx
);
70 SSL_set_connect_state(s
);
71 SSL_set_session(s
, server
->dnstls_data
.session
);
72 SSL_set_bio(s
, TAKE_PTR(rb
), TAKE_PTR(wb
));
75 stream
->dnstls_data
.handshake
= SSL_do_handshake(s
);
76 if (stream
->dnstls_data
.handshake
<= 0) {
77 error
= SSL_get_error(s
, stream
->dnstls_data
.handshake
);
78 if (!IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
81 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
82 log_debug("Failed to invoke SSL_do_handshake: %s", errbuf
);
87 stream
->encrypted
= true;
89 r
= dnstls_flush_write_buffer(stream
);
90 if (r
< 0 && r
!= -EAGAIN
)
93 stream
->dnstls_data
.ssl
= TAKE_PTR(s
);
98 void dnstls_stream_free(DnsStream
*stream
) {
100 assert(stream
->encrypted
);
102 if (stream
->dnstls_data
.ssl
)
103 SSL_free(stream
->dnstls_data
.ssl
);
106 int dnstls_stream_on_io(DnsStream
*stream
, uint32_t revents
) {
110 assert(stream
->encrypted
);
111 assert(stream
->dnstls_data
.ssl
);
113 /* Flush write buffer when requested by OpenSSL */
114 if ((revents
& EPOLLOUT
) && (stream
->dnstls_events
& EPOLLOUT
)) {
115 r
= dnstls_flush_write_buffer(stream
);
120 if (stream
->dnstls_data
.shutdown
) {
122 r
= SSL_shutdown(stream
->dnstls_data
.ssl
);
124 stream
->dnstls_events
= 0;
126 r
= dnstls_flush_write_buffer(stream
);
132 error
= SSL_get_error(stream
->dnstls_data
.ssl
, r
);
133 if (IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
134 stream
->dnstls_events
= error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
136 r
= dnstls_flush_write_buffer(stream
);
141 } else if (error
== SSL_ERROR_SYSCALL
) {
143 log_debug_errno(errno
, "Failed to invoke SSL_shutdown, ignoring: %m");
147 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
148 log_debug("Failed to invoke SSL_shutdown, ignoring: %s", errbuf
);
152 stream
->dnstls_events
= 0;
153 stream
->dnstls_data
.shutdown
= false;
155 r
= dnstls_flush_write_buffer(stream
);
159 dns_stream_unref(stream
);
160 return DNSTLS_STREAM_CLOSED
;
161 } else if (stream
->dnstls_data
.handshake
<= 0) {
163 stream
->dnstls_data
.handshake
= SSL_do_handshake(stream
->dnstls_data
.ssl
);
164 if (stream
->dnstls_data
.handshake
<= 0) {
165 error
= SSL_get_error(stream
->dnstls_data
.ssl
, stream
->dnstls_data
.handshake
);
166 if (IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
167 stream
->dnstls_events
= error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
168 r
= dnstls_flush_write_buffer(stream
);
176 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
177 return log_debug_errno(SYNTHETIC_ERRNO(ECONNREFUSED
),
178 "Failed to invoke SSL_do_handshake: %s",
183 stream
->dnstls_events
= 0;
184 r
= dnstls_flush_write_buffer(stream
);
192 int dnstls_stream_shutdown(DnsStream
*stream
, int error
) {
197 assert(stream
->encrypted
);
198 assert(stream
->dnstls_data
.ssl
);
200 if (stream
->server
) {
201 s
= SSL_get1_session(stream
->dnstls_data
.ssl
);
203 if (stream
->server
->dnstls_data
.session
)
204 SSL_SESSION_free(stream
->server
->dnstls_data
.session
);
206 stream
->server
->dnstls_data
.session
= s
;
210 if (error
== ETIMEDOUT
) {
212 r
= SSL_shutdown(stream
->dnstls_data
.ssl
);
214 if (!stream
->dnstls_data
.shutdown
) {
215 stream
->dnstls_data
.shutdown
= true;
216 dns_stream_ref(stream
);
219 stream
->dnstls_events
= 0;
221 r
= dnstls_flush_write_buffer(stream
);
227 ssl_error
= SSL_get_error(stream
->dnstls_data
.ssl
, r
);
228 if (IN_SET(ssl_error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
229 stream
->dnstls_events
= ssl_error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
230 r
= dnstls_flush_write_buffer(stream
);
231 if (r
< 0 && r
!= -EAGAIN
)
234 if (!stream
->dnstls_data
.shutdown
) {
235 stream
->dnstls_data
.shutdown
= true;
236 dns_stream_ref(stream
);
239 } else if (ssl_error
== SSL_ERROR_SYSCALL
) {
241 log_debug_errno(errno
, "Failed to invoke SSL_shutdown, ignoring: %m");
245 ERR_error_string_n(ssl_error
, errbuf
, sizeof(errbuf
));
246 log_debug("Failed to invoke SSL_shutdown, ignoring: %s", errbuf
);
250 stream
->dnstls_events
= 0;
251 r
= dnstls_flush_write_buffer(stream
);
259 ssize_t
dnstls_stream_write(DnsStream
*stream
, const char *buf
, size_t count
) {
264 assert(stream
->encrypted
);
265 assert(stream
->dnstls_data
.ssl
);
269 ss
= r
= SSL_write(stream
->dnstls_data
.ssl
, buf
, count
);
271 error
= SSL_get_error(stream
->dnstls_data
.ssl
, r
);
272 if (IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
273 stream
->dnstls_events
= error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
275 } else if (error
== SSL_ERROR_ZERO_RETURN
) {
276 stream
->dnstls_events
= 0;
281 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
282 log_debug("Failed to invoke SSL_write: %s", errbuf
);
283 stream
->dnstls_events
= 0;
287 stream
->dnstls_events
= 0;
289 r
= dnstls_flush_write_buffer(stream
);
296 ssize_t
dnstls_stream_read(DnsStream
*stream
, void *buf
, size_t count
) {
301 assert(stream
->encrypted
);
302 assert(stream
->dnstls_data
.ssl
);
306 ss
= r
= SSL_read(stream
->dnstls_data
.ssl
, buf
, count
);
308 error
= SSL_get_error(stream
->dnstls_data
.ssl
, r
);
309 if (IN_SET(error
, SSL_ERROR_WANT_READ
, SSL_ERROR_WANT_WRITE
)) {
310 stream
->dnstls_events
= error
== SSL_ERROR_WANT_READ
? EPOLLIN
: EPOLLOUT
;
312 } else if (error
== SSL_ERROR_ZERO_RETURN
) {
313 stream
->dnstls_events
= 0;
318 ERR_error_string_n(error
, errbuf
, sizeof(errbuf
));
319 log_debug("Failed to invoke SSL_read: %s", errbuf
);
320 stream
->dnstls_events
= 0;
324 stream
->dnstls_events
= 0;
326 /* flush write buffer in cache of renegotiation */
327 r
= dnstls_flush_write_buffer(stream
);
334 void dnstls_server_init(DnsServer
*server
) {
337 server
->dnstls_data
.ctx
= SSL_CTX_new(TLS_client_method());
338 if (server
->dnstls_data
.ctx
) {
339 SSL_CTX_set_min_proto_version(server
->dnstls_data
.ctx
, TLS1_2_VERSION
);
340 SSL_CTX_set_options(server
->dnstls_data
.ctx
, SSL_OP_NO_COMPRESSION
);
344 void dnstls_server_free(DnsServer
*server
) {
347 if (server
->dnstls_data
.ctx
)
348 SSL_CTX_free(server
->dnstls_data
.ctx
);
350 if (server
->dnstls_data
.session
)
351 SSL_SESSION_free(server
->dnstls_data
.session
);