assert(iov);
#if ENABLE_DNS_OVER_TLS
- if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA)) {
- ssize_t ss;
- size_t i;
-
- m = 0;
- for (i = 0; i < iovcnt; i++) {
- ss = dnstls_stream_write(s, iov[i].iov_base, iov[i].iov_len);
- if (ss < 0)
- return ss;
-
- m += ss;
- if (ss != (ssize_t) iov[i].iov_len)
- continue;
- }
- } else
+ if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA))
+ return dnstls_stream_writev(s, iov, iovcnt);
#endif
+
if (s->tfo_salen > 0) {
struct msghdr hdr = {
.msg_iov = (struct iovec*) iov,
#include <gnutls/socket.h>
+#include "io-util.h"
#include "resolved-dns-stream.h"
#include "resolved-dnstls.h"
#include "resolved-manager.h"
#define TLS_PROTOCOL_PRIORITY "NORMAL:-VERS-ALL:+VERS-TLS1.3:+VERS-TLS1.2"
DEFINE_TRIVIAL_CLEANUP_FUNC_FULL(gnutls_session_t, gnutls_deinit, NULL);
-static ssize_t dnstls_stream_writev(gnutls_transport_ptr_t p, const giovec_t *iov, int iovcnt) {
+static ssize_t dnstls_stream_vec_push(gnutls_transport_ptr_t p, const giovec_t *iov, int iovcnt) {
int r;
assert(p);
gnutls_handshake_set_timeout(gs, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
gnutls_transport_set_ptr2(gs, (gnutls_transport_ptr_t) (long) stream->fd, stream);
- gnutls_transport_set_vec_push_function(gs, &dnstls_stream_writev);
+ gnutls_transport_set_vec_push_function(gs, &dnstls_stream_vec_push);
stream->encrypted = true;
stream->dnstls_data.handshake = gnutls_handshake(gs);
return 0;
}
-ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
+ssize_t dnstls_stream_writev(DnsStream *stream, const struct iovec *iov, size_t iovcnt) {
ssize_t ss;
assert(stream);
assert(stream->encrypted);
assert(stream->dnstls_data.session);
- assert(buf);
+ assert(iov);
+ assert(IOVEC_TOTAL_SIZE(iov, iovcnt) > 0);
+
+ gnutls_record_cork(stream->dnstls_data.session);
+
+ for (size_t i = 0; i < iovcnt; i++) {
+ ss = gnutls_record_send(
+ stream->dnstls_data.session,
+ iov[i].iov_base, iov[i].iov_len);
+ if (ss < 0)
+ break;
+ }
- ss = gnutls_record_send(stream->dnstls_data.session, buf, count);
+ ss = gnutls_record_uncork(stream->dnstls_data.session, 0);
if (ss < 0)
switch(ss) {
case GNUTLS_E_INTERRUPTED:
return 0;
}
-ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
+static ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
int error, r;
ssize_t ss;
- assert(stream);
- assert(stream->encrypted);
- assert(stream->dnstls_data.ssl);
- assert(buf);
-
ERR_clear_error();
ss = r = SSL_write(stream->dnstls_data.ssl, buf, count);
if (r <= 0) {
return ss;
}
+ssize_t dnstls_stream_writev(DnsStream *stream, const struct iovec *iov, size_t iovcnt) {
+ _cleanup_free_ char *buf = NULL;
+ size_t count;
+
+ assert(stream);
+ assert(stream->encrypted);
+ assert(stream->dnstls_data.ssl);
+ assert(iov);
+ assert(IOVEC_TOTAL_SIZE(iov, iovcnt) > 0);
+
+ if (iovcnt == 1)
+ return dnstls_stream_write(stream, iov[0].iov_base, iov[0].iov_len);
+
+ /* As of now, OpenSSL can not accumulate multiple writes, so join into a
+ single buffer. Suboptimal, but better than multiple SSL_write calls. */
+ count = IOVEC_TOTAL_SIZE(iov, iovcnt);
+ buf = new(char, count);
+ for (size_t i = 0, pos = 0; i < iovcnt; pos += iov[i].iov_len, i++)
+ memcpy(buf + pos, iov[i].iov_base, iov[i].iov_len);
+
+ return dnstls_stream_write(stream, buf, count);
+}
+
ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) {
int error, r;
ssize_t ss;
#include <stdbool.h>
#include <stdint.h>
+#include <sys/uio.h>
typedef struct DnsServer DnsServer;
typedef struct DnsStream DnsStream;
void dnstls_stream_free(DnsStream *stream);
int dnstls_stream_on_io(DnsStream *stream, uint32_t revents);
int dnstls_stream_shutdown(DnsStream *stream, int error);
-ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count);
+ssize_t dnstls_stream_writev(DnsStream *stream, const struct iovec *iov, size_t iovcnt);
ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count);
bool dnstls_stream_has_buffered_data(DnsStream *stream);