From: W.C.A. Wijngaards Date: Fri, 26 Jun 2020 14:05:15 +0000 (+0200) Subject: tcp read and write handling of write events in netevent for tcp and ssl. X-Git-Tag: release-1.13.0rc1~5^2~53 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=cfe009a31c7b092e236da3e02517b1eb17c954ca;p=thirdparty%2Funbound.git tcp read and write handling of write events in netevent for tcp and ssl. --- diff --git a/util/netevent.c b/util/netevent.c index bcff7a590..6289df823 100644 --- a/util/netevent.c +++ b/util/netevent.c @@ -992,11 +992,12 @@ static void tcp_callback_writer(struct comm_point* c) { log_assert(c->type == comm_tcp); - sldns_buffer_clear(c->buffer); + if(!c->tcp_write_and_read) { + sldns_buffer_clear(c->buffer); + c->tcp_byte_count = 0; + } if(c->tcp_do_toggle_rw) c->tcp_is_reading = 1; - if(!c->tcp_write_and_read) - c->tcp_byte_count = 0; /* switch from listening(write) to listening(read) */ if(c->tcp_req_info) { tcp_req_info_handle_writedone(c->tcp_req_info); @@ -1302,10 +1303,28 @@ ssl_handle_write(struct comm_point* c) } /* ignore return, if fails we may simply block */ (void)SSL_set_mode(c->ssl, (long)SSL_MODE_ENABLE_PARTIAL_WRITE); - if(c->tcp_byte_count < sizeof(uint16_t)) { - uint16_t len = htons(sldns_buffer_limit(c->buffer)); + if((c->tcp_write_and_read?c->tcp_write_byte_count:c->tcp_byte_count) < sizeof(uint16_t)) { + uint16_t len = htons(c->tcp_write_and_read?c->tcp_write_pkt_len:sldns_buffer_limit(c->buffer)); ERR_clear_error(); - if(sizeof(uint16_t)+sldns_buffer_remaining(c->buffer) < + if(c->tcp_write_and_read) { + if(c->tcp_write_pkt_len + 2 < LDNS_RR_BUF_SIZE) { + /* combine the tcp length and the query for + * write, this emulates writev */ + uint8_t buf[LDNS_RR_BUF_SIZE]; + memmove(buf, &len, sizeof(uint16_t)); + memmove(buf+sizeof(uint16_t), + c->tcp_write_pkt, + c->tcp_write_pkt_len); + r = SSL_write(c->ssl, + (void*)(buf+c->tcp_write_byte_count), + c->tcp_write_pkt_len + 2 - + c->tcp_write_byte_count); + } else { + r = SSL_write(c->ssl, + (void*)(((uint8_t*)&len)+c->tcp_write_byte_count), + (int)(sizeof(uint16_t)-c->tcp_write_byte_count)); + } + } else if(sizeof(uint16_t)+sldns_buffer_remaining(c->buffer) < LDNS_RR_BUF_SIZE) { /* combine the tcp length and the query for write, * this emulates writev */ @@ -1347,20 +1366,32 @@ ssl_handle_write(struct comm_point* c) log_crypto_err("could not SSL_write"); return 0; } - c->tcp_byte_count += r; - if(c->tcp_byte_count < sizeof(uint16_t)) - return 1; - sldns_buffer_set_position(c->buffer, c->tcp_byte_count - - sizeof(uint16_t)); - if(sldns_buffer_remaining(c->buffer) == 0) { + if(c->tcp_write_and_read) { + c->tcp_write_byte_count += r; + if(c->tcp_write_byte_count < sizeof(uint16_t)) + return 1; + } else { + c->tcp_byte_count += r; + if(c->tcp_byte_count < sizeof(uint16_t)) + return 1; + sldns_buffer_set_position(c->buffer, c->tcp_byte_count - + sizeof(uint16_t)); + } + if((!c->tcp_write_and_read && sldns_buffer_remaining(c->buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) { tcp_callback_writer(c); return 1; } } - log_assert(sldns_buffer_remaining(c->buffer) > 0); + log_assert(c->tcp_write_and_read || sldns_buffer_remaining(c->buffer) > 0); + log_assert(!c->tcp_write_and_read || c->tcp_write_byte_count < c->tcp_write_pkt_len + 2); ERR_clear_error(); - r = SSL_write(c->ssl, (void*)sldns_buffer_current(c->buffer), - (int)sldns_buffer_remaining(c->buffer)); + if(c->tcp_write_and_read) { + r = SSL_write(c->ssl, (void*)(c->tcp_write_pkt + c->tcp_write_byte_count - 2), + (int)(c->tcp_write_pkt_len + 2 - c->tcp_write_byte_count)); + } else { + r = SSL_write(c->ssl, (void*)sldns_buffer_current(c->buffer), + (int)sldns_buffer_remaining(c->buffer)); + } if(r <= 0) { int want = SSL_get_error(c->ssl, r); if(want == SSL_ERROR_ZERO_RETURN) { @@ -1385,9 +1416,13 @@ ssl_handle_write(struct comm_point* c) log_crypto_err("could not SSL_write"); return 0; } - sldns_buffer_skip(c->buffer, (ssize_t)r); + if(c->tcp_write_and_read) { + c->tcp_write_byte_count += r; + } else { + sldns_buffer_skip(c->buffer, (ssize_t)r); + } - if(sldns_buffer_remaining(c->buffer) == 0) { + if((!c->tcp_write_and_read && sldns_buffer_remaining(c->buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) { tcp_callback_writer(c); } return 1; @@ -1531,7 +1566,7 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c) if(c->tcp_is_reading && !c->ssl) return 0; log_assert(fd != -1); - if(c->tcp_byte_count == 0 && c->tcp_check_nb_connect) { + if(((!c->tcp_write_and_read && c->tcp_byte_count == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == 0)) && c->tcp_check_nb_connect) { /* check for pending error from nonblocking connect */ /* from Stevens, unix network programming, vol1, 3rd ed, p450*/ int error = 0; @@ -1581,15 +1616,22 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c) if(c->tcp_do_fastopen == 1) { /* this form of sendmsg() does both a connect() and send() so need to look for various flavours of error*/ - uint16_t len = htons(sldns_buffer_limit(buffer)); + uint16_t len = htons(c->tcp_write_and_read?c->tcp_write_pkt_len:sldns_buffer_limit(buffer)); struct msghdr msg; struct iovec iov[2]; c->tcp_do_fastopen = 0; memset(&msg, 0, sizeof(msg)); - iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count; - iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count; - iov[1].iov_base = sldns_buffer_begin(buffer); - iov[1].iov_len = sldns_buffer_limit(buffer); + if(c->tcp_write_and_read) { + iov[0].iov_base = (uint8_t*)&len + c->tcp_write_byte_count; + iov[0].iov_len = sizeof(uint16_t) - c->tcp_write_byte_count; + iov[1].iov_base = c->tcp_write_pkt; + iov[1].iov_len = c->tcp_write_pkt_len; + } else { + iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count; + iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count; + iov[1].iov_base = sldns_buffer_begin(buffer); + iov[1].iov_len = sldns_buffer_limit(buffer); + } log_assert(iov[0].iov_len > 0); msg.msg_name = &c->repinfo.addr; msg.msg_namelen = c->repinfo.addrlen; @@ -1635,12 +1677,18 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c) } } else { - c->tcp_byte_count += r; - if(c->tcp_byte_count < sizeof(uint16_t)) - return 1; - sldns_buffer_set_position(buffer, c->tcp_byte_count - - sizeof(uint16_t)); - if(sldns_buffer_remaining(buffer) == 0) { + if(c->tcp_write_and_read) { + c->tcp_write_byte_count += r; + if(c->tcp_write_byte_count < sizeof(uint16_t)) + return 1; + } else { + c->tcp_byte_count += r; + if(c->tcp_byte_count < sizeof(uint16_t)) + return 1; + sldns_buffer_set_position(buffer, c->tcp_byte_count - + sizeof(uint16_t)); + } + if((!c->tcp_write_and_read && sldns_buffer_remaining(buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) { tcp_callback_writer(c); return 1; } @@ -1648,19 +1696,31 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c) } #endif /* USE_MSG_FASTOPEN */ - if(c->tcp_byte_count < sizeof(uint16_t)) { - uint16_t len = htons(sldns_buffer_limit(buffer)); + if((c->tcp_write_and_read?c->tcp_write_byte_count:c->tcp_byte_count) < sizeof(uint16_t)) { + uint16_t len = htons(c->tcp_write_and_read?c->tcp_write_pkt_len:sldns_buffer_limit(buffer)); #ifdef HAVE_WRITEV struct iovec iov[2]; - iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count; - iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count; - iov[1].iov_base = sldns_buffer_begin(buffer); - iov[1].iov_len = sldns_buffer_limit(buffer); + if(c->tcp_write_and_read) { + iov[0].iov_base = (uint8_t*)&len + c->tcp_write_byte_count; + iov[0].iov_len = sizeof(uint16_t) - c->tcp_write_byte_count; + iov[1].iov_base = c->tcp_write_pkt; + iov[1].iov_len = c->tcp_write_pkt_len; + } else { + iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count; + iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count; + iov[1].iov_base = sldns_buffer_begin(buffer); + iov[1].iov_len = sldns_buffer_limit(buffer); + } log_assert(iov[0].iov_len > 0); r = writev(fd, iov, 2); #else /* HAVE_WRITEV */ - r = send(fd, (void*)(((uint8_t*)&len)+c->tcp_byte_count), - sizeof(uint16_t)-c->tcp_byte_count, 0); + if(c->tcp_write_and_read) { + r = send(fd, (void*)(((uint8_t*)&len)+c->tcp_write_byte_count), + sizeof(uint16_t)-c->tcp_write_byte_count, 0); + } else { + r = send(fd, (void*)(((uint8_t*)&len)+c->tcp_byte_count), + sizeof(uint16_t)-c->tcp_byte_count, 0); + } #endif /* HAVE_WRITEV */ if(r == -1) { #ifndef USE_WINSOCK @@ -1699,19 +1759,31 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c) #endif return 0; } - c->tcp_byte_count += r; - if(c->tcp_byte_count < sizeof(uint16_t)) - return 1; - sldns_buffer_set_position(buffer, c->tcp_byte_count - - sizeof(uint16_t)); - if(sldns_buffer_remaining(buffer) == 0) { + if(c->tcp_write_and_read) { + c->tcp_write_byte_count += r; + if(c->tcp_write_byte_count < sizeof(uint16_t)) + return 1; + } else { + c->tcp_byte_count += r; + if(c->tcp_byte_count < sizeof(uint16_t)) + return 1; + sldns_buffer_set_position(buffer, c->tcp_byte_count - + sizeof(uint16_t)); + } + if((!c->tcp_write_and_read && sldns_buffer_remaining(buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) { tcp_callback_writer(c); return 1; } } - log_assert(sldns_buffer_remaining(buffer) > 0); - r = send(fd, (void*)sldns_buffer_current(buffer), - sldns_buffer_remaining(buffer), 0); + log_assert(c->tcp_write_and_read || sldns_buffer_remaining(buffer) > 0); + log_assert(!c->tcp_write_and_read || c->tcp_write_byte_count < c->tcp_write_pkt_len + 2); + if(c->tcp_write_and_read) { + r = send(fd, (void*)c->tcp_write_pkt + c->tcp_write_byte_count - 2, + c->tcp_write_pkt_len + 2 - c->tcp_write_byte_count, 0); + } else { + r = send(fd, (void*)sldns_buffer_current(buffer), + sldns_buffer_remaining(buffer), 0); + } if(r == -1) { #ifndef USE_WINSOCK if(errno == EINTR || errno == EAGAIN) @@ -1736,9 +1808,13 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c) #endif return 0; } - sldns_buffer_skip(buffer, r); + if(c->tcp_write_and_read) { + c->tcp_write_byte_count += r; + } else { + sldns_buffer_skip(buffer, r); + } - if(sldns_buffer_remaining(buffer) == 0) { + if((!c->tcp_write_and_read && sldns_buffer_remaining(buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) { tcp_callback_writer(c); } diff --git a/util/netevent.h b/util/netevent.h index c044f8938..300592e5b 100644 --- a/util/netevent.h +++ b/util/netevent.h @@ -254,11 +254,16 @@ struct comm_point { int tcp_write_and_read; /** byte count for written length over write channel, for when - * tcp_write_and_read is enabled */ + * tcp_write_and_read is enabled. When tcp_write_and_read is enabled, + * this is the counter for writing, the one for reading is in the + * commpoint.buffer sldns buffer. The counter counts from 0 to + * 2+tcp_write_pkt_len, and includes the tcp length bytes. */ size_t tcp_write_byte_count; /** packet to write currently over the write channel. for when - * tcp_write_and_read is enabled */ + * tcp_write_and_read is enabled. When tcp_write_and_read is enabled, + * this is the buffer for the written packet, the commpoint.buffer + * sldns buffer is the buffer for the received packet. */ uint8_t* tcp_write_pkt; /** length of tcp_write_pkt in bytes */ size_t tcp_write_pkt_len;