From: Brian Wellington Date: Fri, 17 Jul 2020 23:37:53 +0000 (-0700) Subject: Changes to blocking model. X-Git-Tag: v2.1.0rc1~176^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4c0fe5541e36e06fccf1a85028bc289d3070374e;p=thirdparty%2Fdnspython.git Changes to blocking model. Before this change, the synchronous code would check sockets for readability or writability before doing nonblocking read or write. This changes them to attempt the read or write first, and then block if the operation could not complete. This also removes the no-longer-needed getpeername() call in tcp(), which was needed to deal with the case where an unconnected socket was passed in; waiting for writability would block rather than immediately return an error. By attempting the write first, we get the error immediately. --- diff --git a/dns/query.py b/dns/query.py index eb827715..dbf9f77f 100644 --- a/dns/query.py +++ b/dns/query.py @@ -342,6 +342,33 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, raise BadResponse return r +def _udp_recv(sock, max_size, expiration): + """Reads a datagram from the socket. + A Timeout exception will be raised if the operation is not completed + by the expiration time. + """ + while True: + try: + return sock.recvfrom(max_size) + except BlockingIOError: + _wait_for_readable(sock, expiration) + + +def _udp_send(sock, data, destination, expiration): + """Sends the specified datagram to destination over the socket. + A Timeout exception will be raised if the operation is not completed + by the expiration time. + """ + while True: + try: + if destination: + return sock.sendto(data, destination) + else: + return sock.send(data) + except BlockingIOError: + _wait_for_writable(sock, expiration) + + def send_udp(sock, what, destination, expiration=None): """Send a DNS message to the specified UDP socket. @@ -361,9 +388,8 @@ def send_udp(sock, what, destination, expiration=None): if isinstance(what, dns.message.Message): what = what.to_wire() - _wait_for_writable(sock, expiration) sent_time = time.time() - n = sock.sendto(what, destination) + n = _udp_send(sock, what, destination, expiration) return (n, sent_time) @@ -413,9 +439,8 @@ def receive_udp(sock, destination=None, expiration=None, """ wire = b'' - while 1: - _wait_for_readable(sock, expiration) - (wire, from_address) = sock.recvfrom(65535) + while True: + (wire, from_address) = _udp_recv(sock, 65535, expiration) if _matches_destination(sock.family, from_address, destination, ignore_unexpected): break @@ -553,18 +578,16 @@ def _net_read(sock, count, expiration): """ s = b'' while count > 0: - _wait_for_readable(sock, expiration) try: n = sock.recv(count) - except ssl.SSLWantReadError: # pragma: no cover - continue + if n == b'': + raise EOFError + count -= len(n) + s += n + except (BlockingIOError, ssl.SSLWantReadError): + _wait_for_readable(sock, expiration) except ssl.SSLWantWriteError: # pragma: no cover _wait_for_writable(sock, expiration) - continue - if n == b'': - raise EOFError - count = count - len(n) - s = s + n return s @@ -576,14 +599,12 @@ def _net_write(sock, data, expiration): current = 0 l = len(data) while current < l: - _wait_for_writable(sock, expiration) try: current += sock.send(data[current:]) + except (BlockingIOError, ssl.SSLWantWriteError): + _wait_for_writable(sock, expiration) except ssl.SSLWantReadError: # pragma: no cover _wait_for_readable(sock, expiration) - continue - except ssl.SSLWantWriteError: # pragma: no cover - continue def send_tcp(sock, what, expiration=None): @@ -607,7 +628,6 @@ def send_tcp(sock, what, expiration=None): # avoid writev() or doing a short write that would get pushed # onto the net tcpmsg = struct.pack("!H", l) + what - _wait_for_writable(sock, expiration) sent_time = time.time() _net_write(sock, tcpmsg, expiration) return (len(tcpmsg), sent_time) @@ -697,11 +717,6 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0, (begin_time, expiration) = _compute_times(timeout) with contextlib.ExitStack() as stack: if sock: - # - # Verify that the socket is connected, as if it's not connected, - # it's not writable, and the polling in send_tcp() will time out or - # hang forever. - sock.getpeername() s = sock else: (af, destination, source) = _destination_and_source(where, port, @@ -881,8 +896,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, _connect(s, destination, expiration) l = len(wire) if use_udp: - _wait_for_writable(s, expiration) - s.send(wire) + _udp_send(s, wire, None, expiration) else: tcpmsg = struct.pack("!H", l) + wire _net_write(s, tcpmsg, expiration) @@ -903,8 +917,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, (expiration is not None and mexpiration > expiration): mexpiration = expiration if use_udp: - _wait_for_readable(s, expiration) - (wire, from_address) = s.recvfrom(65535) + (wire, from_address) = _udp_recv(s, 65535, expiration) else: ldata = _net_read(s, 2, mexpiration) (l,) = struct.unpack("!H", ldata)