resolved: add missing error code check when initializing DNS-over-TLS
[thirdparty/systemd.git] / src / resolve / resolved-dnstls-gnutls.c
1 /* SPDX-License-Identifier: LGPL-2.1+ */
2
3 #if !ENABLE_DNS_OVER_TLS || !DNS_OVER_TLS_USE_GNUTLS
4 #error This source file requires DNS-over-TLS to be enabled and GnuTLS to be available.
5 #endif
6
7 #include <gnutls/socket.h>
8
9 #include "resolved-dns-stream.h"
10 #include "resolved-dnstls.h"
11
/* Generates gnutls_deinitp(), the cleanup helper used with _cleanup_(gnutls_deinitp) on
 * gnutls_session_t variables (see dnstls_stream_connect_tls()). NOTE(review): presumably it
 * calls gnutls_deinit() only on a non-NULL session — confirm against systemd's macro.h. */
DEFINE_TRIVIAL_CLEANUP_FUNC(gnutls_session_t, gnutls_deinit);
13
14 static ssize_t dnstls_stream_writev(gnutls_transport_ptr_t p, const giovec_t *iov, int iovcnt) {
15         int r;
16
17         assert(p);
18
19         r = dns_stream_writev((DnsStream*) p, (const struct iovec*) iov, iovcnt, DNS_STREAM_WRITE_TLS_DATA);
20         if (r < 0) {
21                 errno = -r;
22                 return -1;
23         }
24
25         return r;
26 }
27
28 int dnstls_stream_connect_tls(DnsStream *stream, DnsServer *server) {
29         _cleanup_(gnutls_deinitp) gnutls_session_t gs;
30         int r;
31
32         assert(stream);
33         assert(server);
34
35         r = gnutls_init(&gs, GNUTLS_CLIENT | GNUTLS_ENABLE_FALSE_START | GNUTLS_NONBLOCK);
36         if (r < 0)
37                 return r;
38
39         /* As DNS-over-TLS is a recent protocol, older TLS versions can be disabled */
40         r = gnutls_priority_set_direct(gs, "NORMAL:-VERS-ALL:+VERS-TLS1.2", NULL);
41         if (r < 0)
42                 return r;
43
44         r = gnutls_credentials_set(gs, GNUTLS_CRD_CERTIFICATE, stream->manager->dnstls_data.cert_cred);
45         if (r < 0)
46                 return r;
47
48         if (server->dnstls_data.session_data.size > 0) {
49                 gnutls_session_set_data(gs, server->dnstls_data.session_data.data, server->dnstls_data.session_data.size);
50
51                 // Clear old session ticket
52                 gnutls_free(server->dnstls_data.session_data.data);
53                 server->dnstls_data.session_data.data = NULL;
54                 server->dnstls_data.session_data.size = 0;
55         }
56
57         gnutls_handshake_set_timeout(gs, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
58
59         gnutls_transport_set_ptr2(gs, (gnutls_transport_ptr_t) (long) stream->fd, stream);
60         gnutls_transport_set_vec_push_function(gs, &dnstls_stream_writev);
61
62         stream->encrypted = true;
63         stream->dnstls_data.handshake = gnutls_handshake(gs);
64         if (stream->dnstls_data.handshake < 0 && gnutls_error_is_fatal(stream->dnstls_data.handshake))
65                 return -ECONNREFUSED;
66
67         stream->dnstls_data.session = TAKE_PTR(gs);
68
69         return 0;
70 }
71
72 void dnstls_stream_free(DnsStream *stream) {
73         assert(stream);
74         assert(stream->encrypted);
75
76         if (stream->dnstls_data.session)
77                 gnutls_deinit(stream->dnstls_data.session);
78 }
79
80 int dnstls_stream_on_io(DnsStream *stream, uint32_t revents) {
81         int r;
82
83         assert(stream);
84         assert(stream->encrypted);
85         assert(stream->dnstls_data.session);
86
87         if (stream->dnstls_data.shutdown) {
88                 r = gnutls_bye(stream->dnstls_data.session, GNUTLS_SHUT_RDWR);
89                 if (r == GNUTLS_E_AGAIN) {
90                         stream->dnstls_events = gnutls_record_get_direction(stream->dnstls_data.session) == 1 ? EPOLLOUT : EPOLLIN;
91                         return -EAGAIN;
92                 } else if (r < 0)
93                         log_debug("Failed to invoke gnutls_bye: %s", gnutls_strerror(r));
94
95                 stream->dnstls_events = 0;
96                 stream->dnstls_data.shutdown = false;
97                 dns_stream_unref(stream);
98                 return DNSTLS_STREAM_CLOSED;
99         } else if (stream->dnstls_data.handshake < 0) {
100                 stream->dnstls_data.handshake = gnutls_handshake(stream->dnstls_data.session);
101                 if (stream->dnstls_data.handshake == GNUTLS_E_AGAIN) {
102                         stream->dnstls_events = gnutls_record_get_direction(stream->dnstls_data.session) == 1 ? EPOLLOUT : EPOLLIN;
103                         return -EAGAIN;
104                 } else if (stream->dnstls_data.handshake < 0) {
105                         log_debug("Failed to invoke gnutls_handshake: %s", gnutls_strerror(stream->dnstls_data.handshake));
106                         if (gnutls_error_is_fatal(stream->dnstls_data.handshake))
107                                 return -ECONNREFUSED;
108                 }
109
110                 stream->dnstls_events = 0;
111         }
112
113         return 0;
114 }
115
116 int dnstls_stream_shutdown(DnsStream *stream, int error) {
117         int r;
118
119         assert(stream);
120         assert(stream->encrypted);
121         assert(stream->dnstls_data.session);
122
123         /* Store TLS Ticket for faster successive TLS handshakes */
124         if (stream->server && stream->server->dnstls_data.session_data.size == 0 && stream->dnstls_data.handshake == GNUTLS_E_SUCCESS)
125                 gnutls_session_get_data2(stream->dnstls_data.session, &stream->server->dnstls_data.session_data);
126
127         if (IN_SET(error, ETIMEDOUT, 0)) {
128                 r = gnutls_bye(stream->dnstls_data.session, GNUTLS_SHUT_RDWR);
129                 if (r == GNUTLS_E_AGAIN) {
130                         if (!stream->dnstls_data.shutdown) {
131                                 stream->dnstls_data.shutdown = true;
132                                 dns_stream_ref(stream);
133                                 return -EAGAIN;
134                         }
135                 } else if (r < 0)
136                         log_debug("Failed to invoke gnutls_bye: %s", gnutls_strerror(r));
137         }
138
139         return 0;
140 }
141
142 ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
143         ssize_t ss;
144
145         assert(stream);
146         assert(stream->encrypted);
147         assert(stream->dnstls_data.session);
148         assert(buf);
149
150         ss = gnutls_record_send(stream->dnstls_data.session, buf, count);
151         if (ss < 0)
152                 switch(ss) {
153                 case GNUTLS_E_INTERRUPTED:
154                         return -EINTR;
155                 case GNUTLS_E_AGAIN:
156                         return -EAGAIN;
157                 default:
158                         return log_debug_errno(SYNTHETIC_ERRNO(EPIPE),
159                                                "Failed to invoke gnutls_record_send: %s",
160                                                gnutls_strerror(ss));
161                 }
162
163         return ss;
164 }
165
166 ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) {
167         ssize_t ss;
168
169         assert(stream);
170         assert(stream->encrypted);
171         assert(stream->dnstls_data.session);
172         assert(buf);
173
174         ss = gnutls_record_recv(stream->dnstls_data.session, buf, count);
175         if (ss < 0)
176                 switch(ss) {
177                 case GNUTLS_E_INTERRUPTED:
178                         return -EINTR;
179                 case GNUTLS_E_AGAIN:
180                         return -EAGAIN;
181                 default:
182                         return log_debug_errno(SYNTHETIC_ERRNO(EPIPE),
183                                                "Failed to invoke gnutls_record_recv: %s",
184                                                gnutls_strerror(ss));
185                 }
186
187         return ss;
188 }
189
190 void dnstls_server_free(DnsServer *server) {
191         assert(server);
192
193         if (server->dnstls_data.session_data.data)
194                 gnutls_free(server->dnstls_data.session_data.data);
195 }
196
197 int dnstls_manager_init(Manager *manager) {
198         int r;
199         assert(manager);
200
201         r = gnutls_certificate_allocate_credentials(&manager->dnstls_data.cert_cred);
202         if (r < 0)
203                 return -ENOMEM;
204
205         return 0;
206 }
207
208 void dnstls_manager_free(Manager *manager) {
209         assert(manager);
210
211         if (manager->dnstls_data.cert_cred)
212                 gnutls_certificate_free_credentials(manager->dnstls_data.cert_cred);
213 }