]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
resolved: Avoid multiple SSL writes per DoT packet
authorJoan Bruguera <joanbrugueram@gmail.com>
Mon, 31 Jan 2022 20:28:32 +0000 (21:28 +0100)
committerJoan Bruguera <joanbrugueram@gmail.com>
Tue, 1 Feb 2022 18:24:40 +0000 (19:24 +0100)
In the DoT case, dns_stream_writev decomposed an iovec into multiple
dnstls_stream_write calls, which resulted in multiple SSL writes and multiple
TLS records. This can be checked from a network capture, e.g. using socat:
socat -v -x openssl-listen:853,reuseaddr,fork,cert=my.cert,key=my.key,verify=0 openssl:8.8.8.8:853

Instead, propagate the iovec as-is into the DoT handling code. For GnuTLS, the
library provides support for buffering ('corking') a record. OpenSSL has no
such facility, so we join the iovec into a single buffer then call SSL_write.

socat capture of `resolvectl -4 query --cache=no example.com` before the commit:

> 2022/01/30 13:35:52.194200  length=2 from=0 to=1
 00 28                                            .(
--
> 2022/01/30 13:35:52.194253  length=40 from=2 to=41
 1e b2 01 00 00 01 00 00 00 00 00 01 07 65 78 61  .............exa
 6d 70 6c 65 03 63 6f 6d 00 00 01 00 01 00 00 29  mple.com.......)
 ff e4 00 00 00 00 00 00                          ........
--
< 2022/01/30 13:35:52.232798  length=58 from=0 to=57
 00 38 1e b2 81 80 00 01 00 01 00 00 00 01 07 65  .8.............e
 78 61 6d 70 6c 65 03 63 6f 6d 00 00 01 00 01 c0  xample.com......
 0c 00 01 00 01 00 00 53 6f 00 04 5d b8 d8 22 00  .......So..]..".
 00 29 02 00 00 00 00 00 00 00                    .)........

socat capture of `resolvectl -4 query --cache=no example.com` after the commit:

> 2022/01/30 13:34:47.598099  length=42 from=504 to=545
 00 28 37 86 01 00 00 01 00 00 00 00 00 01 07 65  .(7............e
 78 61 6d 70 6c 65 03 63 6f 6d 00 00 01 00 01 00  xample.com......
 00 29 ff e4 00 00 00 00 00 00                    .)........
--
< 2022/01/30 13:34:47.613203  length=58 from=756 to=813
 00 38 37 86 81 80 00 01 00 01 00 00 00 01 07 65  .87............e
 78 61 6d 70 6c 65 03 63 6f 6d 00 00 01 00 01 c0  xample.com......
 0c 00 01 00 01 00 00 52 5e 00 04 5d b8 d8 22 00  .......R^..]..".
 00 29 02 00 00 00 00 00 00 00                    .)........

src/resolve/resolved-dns-stream.c
src/resolve/resolved-dnstls-gnutls.c
src/resolve/resolved-dnstls-openssl.c
src/resolve/resolved-dnstls.h

index 290c28ed652c251c95aef91f30c0144e05ca69ab..5c4a9ebb9990ee63823b035f8583fd8938d3f4b3 100644 (file)
@@ -210,22 +210,10 @@ ssize_t dns_stream_writev(DnsStream *s, const struct iovec *iov, size_t iovcnt,
         assert(iov);
 
 #if ENABLE_DNS_OVER_TLS
-        if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA)) {
-                ssize_t ss;
-                size_t i;
-
-                m = 0;
-                for (i = 0; i < iovcnt; i++) {
-                        ss = dnstls_stream_write(s, iov[i].iov_base, iov[i].iov_len);
-                        if (ss < 0)
-                                return ss;
-
-                        m += ss;
-                        if (ss != (ssize_t) iov[i].iov_len)
-                                continue;
-                }
-        } else
+        if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA))
+                return dnstls_stream_writev(s, iov, iovcnt);
 #endif
+
         if (s->tfo_salen > 0) {
                 struct msghdr hdr = {
                         .msg_iov = (struct iovec*) iov,
index 8610cacab67cad01c1177432a807854b5444899a..3d361708a10fe0682d4d458052d27a6dece151ea 100644 (file)
@@ -6,6 +6,7 @@
 
 #include <gnutls/socket.h>
 
+#include "io-util.h"
 #include "resolved-dns-stream.h"
 #include "resolved-dnstls.h"
 #include "resolved-manager.h"
@@ -13,7 +14,7 @@
 #define TLS_PROTOCOL_PRIORITY "NORMAL:-VERS-ALL:+VERS-TLS1.3:+VERS-TLS1.2"
 DEFINE_TRIVIAL_CLEANUP_FUNC_FULL(gnutls_session_t, gnutls_deinit, NULL);
 
-static ssize_t dnstls_stream_writev(gnutls_transport_ptr_t p, const giovec_t *iov, int iovcnt) {
+static ssize_t dnstls_stream_vec_push(gnutls_transport_ptr_t p, const giovec_t *iov, int iovcnt) {
         int r;
 
         assert(p);
@@ -81,7 +82,7 @@ int dnstls_stream_connect_tls(DnsStream *stream, DnsServer *server) {
         gnutls_handshake_set_timeout(gs, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
 
         gnutls_transport_set_ptr2(gs, (gnutls_transport_ptr_t) (long) stream->fd, stream);
-        gnutls_transport_set_vec_push_function(gs, &dnstls_stream_writev);
+        gnutls_transport_set_vec_push_function(gs, &dnstls_stream_vec_push);
 
         stream->encrypted = true;
         stream->dnstls_data.handshake = gnutls_handshake(gs);
@@ -163,15 +164,26 @@ int dnstls_stream_shutdown(DnsStream *stream, int error) {
         return 0;
 }
 
-ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
+ssize_t dnstls_stream_writev(DnsStream *stream, const struct iovec *iov, size_t iovcnt) {
         ssize_t ss;
 
         assert(stream);
         assert(stream->encrypted);
         assert(stream->dnstls_data.session);
-        assert(buf);
+        assert(iov);
+        assert(IOVEC_TOTAL_SIZE(iov, iovcnt) > 0);
+
+        gnutls_record_cork(stream->dnstls_data.session);
+
+        for (size_t i = 0; i < iovcnt; i++) {
+                ss = gnutls_record_send(
+                        stream->dnstls_data.session,
+                        iov[i].iov_base, iov[i].iov_len);
+                if (ss < 0)
+                        break;
+        }
 
-        ss = gnutls_record_send(stream->dnstls_data.session, buf, count);
+        ss = gnutls_record_uncork(stream->dnstls_data.session, 0);
         if (ss < 0)
                 switch(ss) {
                 case GNUTLS_E_INTERRUPTED:
index 7d264dd367365272f72d0920acd4dc0ff7b28ed7..3a030048625615118da45674317db2e7404acd61 100644 (file)
@@ -292,15 +292,10 @@ int dnstls_stream_shutdown(DnsStream *stream, int error) {
         return 0;
 }
 
-ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
+static ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
         int error, r;
         ssize_t ss;
 
-        assert(stream);
-        assert(stream->encrypted);
-        assert(stream->dnstls_data.ssl);
-        assert(buf);
-
         ERR_clear_error();
         ss = r = SSL_write(stream->dnstls_data.ssl, buf, count);
         if (r <= 0) {
@@ -329,6 +324,29 @@ ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
         return ss;
 }
 
+ssize_t dnstls_stream_writev(DnsStream *stream, const struct iovec *iov, size_t iovcnt) {
+        _cleanup_free_ char *buf = NULL;
+        size_t count;
+
+        assert(stream);
+        assert(stream->encrypted);
+        assert(stream->dnstls_data.ssl);
+        assert(iov);
+        assert(IOVEC_TOTAL_SIZE(iov, iovcnt) > 0);
+
+        if (iovcnt == 1)
+                return dnstls_stream_write(stream, iov[0].iov_base, iov[0].iov_len);
+
+        /* As of now, OpenSSL can not accumulate multiple writes, so join into a
+           single buffer. Suboptimal, but better than multiple SSL_write calls. */
+        count = IOVEC_TOTAL_SIZE(iov, iovcnt);
+        buf = new(char, count);
+        for (size_t i = 0, pos = 0; i < iovcnt; pos += iov[i].iov_len, i++)
+                memcpy(buf + pos, iov[i].iov_base, iov[i].iov_len);
+
+        return dnstls_stream_write(stream, buf, count);
+}
+
 ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) {
         int error, r;
         ssize_t ss;
index ed214dc6c46cea00197dbf513b6434597a0d07fc..70b27d8d77fcd5695bdf8d8b73ce9bb193415e2a 100644 (file)
@@ -5,6 +5,7 @@
 
 #include <stdbool.h>
 #include <stdint.h>
+#include <sys/uio.h>
 
 typedef struct DnsServer DnsServer;
 typedef struct DnsStream DnsStream;
@@ -27,7 +28,7 @@ int dnstls_stream_connect_tls(DnsStream *stream, DnsServer *server);
 void dnstls_stream_free(DnsStream *stream);
 int dnstls_stream_on_io(DnsStream *stream, uint32_t revents);
 int dnstls_stream_shutdown(DnsStream *stream, int error);
-ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count);
+ssize_t dnstls_stream_writev(DnsStream *stream, const struct iovec *iov, size_t iovcnt);
 ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count);
 bool dnstls_stream_has_buffered_data(DnsStream *stream);