git.ipfire.org Git - thirdparty/unbound.git/commitdiff
tcp read and write handling of write events in netevent for tcp and ssl.
author: W.C.A. Wijngaards <wouter@nlnetlabs.nl>
Fri, 26 Jun 2020 14:05:15 +0000 (16:05 +0200)
committer: W.C.A. Wijngaards <wouter@nlnetlabs.nl>
Fri, 26 Jun 2020 14:05:15 +0000 (16:05 +0200)
util/netevent.c
util/netevent.h

index bcff7a590a916af1fc10d012e8c030dcebff25c7..6289df82399dd11d1d4b3620f58cd2ad185fab6a 100644 (file)
@@ -992,11 +992,12 @@ static void
 tcp_callback_writer(struct comm_point* c)
 {
        log_assert(c->type == comm_tcp);
-       sldns_buffer_clear(c->buffer);
+       if(!c->tcp_write_and_read) {
+               sldns_buffer_clear(c->buffer);
+               c->tcp_byte_count = 0;
+       }
        if(c->tcp_do_toggle_rw)
                c->tcp_is_reading = 1;
-       if(!c->tcp_write_and_read)
-               c->tcp_byte_count = 0;
        /* switch from listening(write) to listening(read) */
        if(c->tcp_req_info) {
                tcp_req_info_handle_writedone(c->tcp_req_info);
@@ -1302,10 +1303,28 @@ ssl_handle_write(struct comm_point* c)
        }
        /* ignore return, if fails we may simply block */
        (void)SSL_set_mode(c->ssl, (long)SSL_MODE_ENABLE_PARTIAL_WRITE);
-       if(c->tcp_byte_count < sizeof(uint16_t)) {
-               uint16_t len = htons(sldns_buffer_limit(c->buffer));
+       if((c->tcp_write_and_read?c->tcp_write_byte_count:c->tcp_byte_count) < sizeof(uint16_t)) {
+               uint16_t len = htons(c->tcp_write_and_read?c->tcp_write_pkt_len:sldns_buffer_limit(c->buffer));
                ERR_clear_error();
-               if(sizeof(uint16_t)+sldns_buffer_remaining(c->buffer) <
+               if(c->tcp_write_and_read) {
+                       if(c->tcp_write_pkt_len + 2 < LDNS_RR_BUF_SIZE) {
+                               /* combine the tcp length and the query for
+                                * write, this emulates writev */
+                               uint8_t buf[LDNS_RR_BUF_SIZE];
+                               memmove(buf, &len, sizeof(uint16_t));
+                               memmove(buf+sizeof(uint16_t),
+                                       c->tcp_write_pkt,
+                                       c->tcp_write_pkt_len);
+                               r = SSL_write(c->ssl,
+                                       (void*)(buf+c->tcp_write_byte_count),
+                                       c->tcp_write_pkt_len + 2 -
+                                       c->tcp_write_byte_count);
+                       } else {
+                               r = SSL_write(c->ssl,
+                                       (void*)(((uint8_t*)&len)+c->tcp_write_byte_count),
+                                       (int)(sizeof(uint16_t)-c->tcp_write_byte_count));
+                       }
+               } else if(sizeof(uint16_t)+sldns_buffer_remaining(c->buffer) <
                        LDNS_RR_BUF_SIZE) {
                        /* combine the tcp length and the query for write,
                         * this emulates writev */
@@ -1347,20 +1366,32 @@ ssl_handle_write(struct comm_point* c)
                        log_crypto_err("could not SSL_write");
                        return 0;
                }
-               c->tcp_byte_count += r;
-               if(c->tcp_byte_count < sizeof(uint16_t))
-                       return 1;
-               sldns_buffer_set_position(c->buffer, c->tcp_byte_count -
-                       sizeof(uint16_t));
-               if(sldns_buffer_remaining(c->buffer) == 0) {
+               if(c->tcp_write_and_read) {
+                       c->tcp_write_byte_count += r;
+                       if(c->tcp_write_byte_count < sizeof(uint16_t))
+                               return 1;
+               } else {
+                       c->tcp_byte_count += r;
+                       if(c->tcp_byte_count < sizeof(uint16_t))
+                               return 1;
+                       sldns_buffer_set_position(c->buffer, c->tcp_byte_count -
+                               sizeof(uint16_t));
+               }
+               if((!c->tcp_write_and_read && sldns_buffer_remaining(c->buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
                        tcp_callback_writer(c);
                        return 1;
                }
        }
-       log_assert(sldns_buffer_remaining(c->buffer) > 0);
+       log_assert(c->tcp_write_and_read || sldns_buffer_remaining(c->buffer) > 0);
+       log_assert(!c->tcp_write_and_read || c->tcp_write_byte_count < c->tcp_write_pkt_len + 2);
        ERR_clear_error();
-       r = SSL_write(c->ssl, (void*)sldns_buffer_current(c->buffer),
-               (int)sldns_buffer_remaining(c->buffer));
+       if(c->tcp_write_and_read) {
+               r = SSL_write(c->ssl, (void*)(c->tcp_write_pkt + c->tcp_write_byte_count - 2),
+                       (int)(c->tcp_write_pkt_len + 2 - c->tcp_write_byte_count));
+       } else {
+               r = SSL_write(c->ssl, (void*)sldns_buffer_current(c->buffer),
+                       (int)sldns_buffer_remaining(c->buffer));
+       }
        if(r <= 0) {
                int want = SSL_get_error(c->ssl, r);
                if(want == SSL_ERROR_ZERO_RETURN) {
@@ -1385,9 +1416,13 @@ ssl_handle_write(struct comm_point* c)
                log_crypto_err("could not SSL_write");
                return 0;
        }
-       sldns_buffer_skip(c->buffer, (ssize_t)r);
+       if(c->tcp_write_and_read) {
+               c->tcp_write_byte_count += r;
+       } else {
+               sldns_buffer_skip(c->buffer, (ssize_t)r);
+       }
 
-       if(sldns_buffer_remaining(c->buffer) == 0) {
+       if((!c->tcp_write_and_read && sldns_buffer_remaining(c->buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
                tcp_callback_writer(c);
        }
        return 1;
@@ -1531,7 +1566,7 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
        if(c->tcp_is_reading && !c->ssl)
                return 0;
        log_assert(fd != -1);
-       if(c->tcp_byte_count == 0 && c->tcp_check_nb_connect) {
+       if(((!c->tcp_write_and_read && c->tcp_byte_count == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == 0)) && c->tcp_check_nb_connect) {
                /* check for pending error from nonblocking connect */
                /* from Stevens, unix network programming, vol1, 3rd ed, p450*/
                int error = 0;
@@ -1581,15 +1616,22 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
        if(c->tcp_do_fastopen == 1) {
                /* this form of sendmsg() does both a connect() and send() so need to
                   look for various flavours of error*/
-               uint16_t len = htons(sldns_buffer_limit(buffer));
+               uint16_t len = htons(c->tcp_write_and_read?c->tcp_write_pkt_len:sldns_buffer_limit(buffer));
                struct msghdr msg;
                struct iovec iov[2];
                c->tcp_do_fastopen = 0;
                memset(&msg, 0, sizeof(msg));
-               iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count;
-               iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count;
-               iov[1].iov_base = sldns_buffer_begin(buffer);
-               iov[1].iov_len = sldns_buffer_limit(buffer);
+               if(c->tcp_write_and_read) {
+                       iov[0].iov_base = (uint8_t*)&len + c->tcp_write_byte_count;
+                       iov[0].iov_len = sizeof(uint16_t) - c->tcp_write_byte_count;
+                       iov[1].iov_base = c->tcp_write_pkt;
+                       iov[1].iov_len = c->tcp_write_pkt_len;
+               } else {
+                       iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count;
+                       iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count;
+                       iov[1].iov_base = sldns_buffer_begin(buffer);
+                       iov[1].iov_len = sldns_buffer_limit(buffer);
+               }
                log_assert(iov[0].iov_len > 0);
                msg.msg_name = &c->repinfo.addr;
                msg.msg_namelen = c->repinfo.addrlen;
@@ -1635,12 +1677,18 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
                        }
 
                } else {
-                       c->tcp_byte_count += r;
-                       if(c->tcp_byte_count < sizeof(uint16_t))
-                               return 1;
-                       sldns_buffer_set_position(buffer, c->tcp_byte_count - 
-                               sizeof(uint16_t));
-                       if(sldns_buffer_remaining(buffer) == 0) {
+                       if(c->tcp_write_and_read) {
+                               c->tcp_write_byte_count += r;
+                               if(c->tcp_write_byte_count < sizeof(uint16_t))
+                                       return 1;
+                       } else {
+                               c->tcp_byte_count += r;
+                               if(c->tcp_byte_count < sizeof(uint16_t))
+                                       return 1;
+                               sldns_buffer_set_position(buffer, c->tcp_byte_count -
+                                       sizeof(uint16_t));
+                       }
+                       if((!c->tcp_write_and_read && sldns_buffer_remaining(buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
                                tcp_callback_writer(c);
                                return 1;
                        }
@@ -1648,19 +1696,31 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
        }
 #endif /* USE_MSG_FASTOPEN */
 
-       if(c->tcp_byte_count < sizeof(uint16_t)) {
-               uint16_t len = htons(sldns_buffer_limit(buffer));
+       if((c->tcp_write_and_read?c->tcp_write_byte_count:c->tcp_byte_count) < sizeof(uint16_t)) {
+               uint16_t len = htons(c->tcp_write_and_read?c->tcp_write_pkt_len:sldns_buffer_limit(buffer));
 #ifdef HAVE_WRITEV
                struct iovec iov[2];
-               iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count;
-               iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count;
-               iov[1].iov_base = sldns_buffer_begin(buffer);
-               iov[1].iov_len = sldns_buffer_limit(buffer);
+               if(c->tcp_write_and_read) {
+                       iov[0].iov_base = (uint8_t*)&len + c->tcp_write_byte_count;
+                       iov[0].iov_len = sizeof(uint16_t) - c->tcp_write_byte_count;
+                       iov[1].iov_base = c->tcp_write_pkt;
+                       iov[1].iov_len = c->tcp_write_pkt_len;
+               } else {
+                       iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count;
+                       iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count;
+                       iov[1].iov_base = sldns_buffer_begin(buffer);
+                       iov[1].iov_len = sldns_buffer_limit(buffer);
+               }
                log_assert(iov[0].iov_len > 0);
                r = writev(fd, iov, 2);
 #else /* HAVE_WRITEV */
-               r = send(fd, (void*)(((uint8_t*)&len)+c->tcp_byte_count),
-                       sizeof(uint16_t)-c->tcp_byte_count, 0);
+               if(c->tcp_write_and_read) {
+                       r = send(fd, (void*)(((uint8_t*)&len)+c->tcp_write_byte_count),
+                               sizeof(uint16_t)-c->tcp_write_byte_count, 0);
+               } else {
+                       r = send(fd, (void*)(((uint8_t*)&len)+c->tcp_byte_count),
+                               sizeof(uint16_t)-c->tcp_byte_count, 0);
+               }
 #endif /* HAVE_WRITEV */
                if(r == -1) {
 #ifndef USE_WINSOCK
@@ -1699,19 +1759,31 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
 #endif
                        return 0;
                }
-               c->tcp_byte_count += r;
-               if(c->tcp_byte_count < sizeof(uint16_t))
-                       return 1;
-               sldns_buffer_set_position(buffer, c->tcp_byte_count - 
-                       sizeof(uint16_t));
-               if(sldns_buffer_remaining(buffer) == 0) {
+               if(c->tcp_write_and_read) {
+                       c->tcp_write_byte_count += r;
+                       if(c->tcp_write_byte_count < sizeof(uint16_t))
+                               return 1;
+               } else {
+                       c->tcp_byte_count += r;
+                       if(c->tcp_byte_count < sizeof(uint16_t))
+                               return 1;
+                       sldns_buffer_set_position(buffer, c->tcp_byte_count -
+                               sizeof(uint16_t));
+               }
+               if((!c->tcp_write_and_read && sldns_buffer_remaining(buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
                        tcp_callback_writer(c);
                        return 1;
                }
        }
-       log_assert(sldns_buffer_remaining(buffer) > 0);
-       r = send(fd, (void*)sldns_buffer_current(buffer), 
-               sldns_buffer_remaining(buffer), 0);
+       log_assert(c->tcp_write_and_read || sldns_buffer_remaining(buffer) > 0);
+       log_assert(!c->tcp_write_and_read || c->tcp_write_byte_count < c->tcp_write_pkt_len + 2);
+       if(c->tcp_write_and_read) {
+               r = send(fd, (void*)c->tcp_write_pkt + c->tcp_write_byte_count - 2,
+                       c->tcp_write_pkt_len + 2 - c->tcp_write_byte_count, 0);
+       } else {
+               r = send(fd, (void*)sldns_buffer_current(buffer),
+                       sldns_buffer_remaining(buffer), 0);
+       }
        if(r == -1) {
 #ifndef USE_WINSOCK
                if(errno == EINTR || errno == EAGAIN)
@@ -1736,9 +1808,13 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
 #endif
                return 0;
        }
-       sldns_buffer_skip(buffer, r);
+       if(c->tcp_write_and_read) {
+               c->tcp_write_byte_count += r;
+       } else {
+               sldns_buffer_skip(buffer, r);
+       }
 
-       if(sldns_buffer_remaining(buffer) == 0) {
+       if((!c->tcp_write_and_read && sldns_buffer_remaining(buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
                tcp_callback_writer(c);
        }
        
index c044f8938dcefebb1a090f4935e275d0ad7961ca..300592e5bdfcaee0fb22b3245ec81ddc524150b9 100644 (file)
@@ -254,11 +254,16 @@ struct comm_point {
        int tcp_write_and_read;
 
        /** byte count for written length over write channel, for when
-        * tcp_write_and_read is enabled */
+        * tcp_write_and_read is enabled.  When tcp_write_and_read is enabled,
+        * this is the counter for writing, the one for reading is in the
+        * commpoint.buffer sldns buffer.  The counter counts from 0 to
+        * 2+tcp_write_pkt_len, and includes the tcp length bytes. */
        size_t tcp_write_byte_count;
 
        /** packet to write currently over the write channel. for when
-        * tcp_write_and_read is enabled */
+        * tcp_write_and_read is enabled.  When tcp_write_and_read is enabled,
+        * this is the buffer for the written packet, the commpoint.buffer
+        * sldns buffer is the buffer for the received packet. */
        uint8_t* tcp_write_pkt;
        /** length of tcp_write_pkt in bytes */
        size_t tcp_write_pkt_len;