]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Changes to blocking model.
authorBrian Wellington <bwelling@xbill.org>
Fri, 17 Jul 2020 23:37:53 +0000 (16:37 -0700)
committerBrian Wellington <bwelling@xbill.org>
Fri, 17 Jul 2020 23:39:40 +0000 (16:39 -0700)
Before this change, the synchronous code would check sockets for
readability or writability before doing nonblocking read or write.
This changes them to attempt the read or write first, and then block
if the operation could not complete.

This also removes the no-longer-needed getpeername() call in tcp(),
which was needed to deal with the case where an unconnected socket was
passed in; waiting for writability would block rather than immediately
return an error.  By attempting the write first, we get the error
immediately.

dns/query.py

index eb82771564ce0a1f564700862945e570b5b2c846..dbf9f77fbe8d1474ab9686a89f1ef840aad5c027 100644 (file)
@@ -342,6 +342,33 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
         raise BadResponse
     return r
 
+def _udp_recv(sock, max_size, expiration):
+    """Reads a datagram from the socket.
+    A Timeout exception will be raised if the operation is not completed
+    by the expiration time.
+    """
+    while True:
+        try:
+            return sock.recvfrom(max_size)
+        except BlockingIOError:
+            _wait_for_readable(sock, expiration)
+
+
+def _udp_send(sock, data, destination, expiration):
+    """Sends the specified datagram to destination over the socket.
+    A Timeout exception will be raised if the operation is not completed
+    by the expiration time.
+    """
+    while True:
+        try:
+            if destination:
+                return sock.sendto(data, destination)
+            else:
+                return sock.send(data)
+        except BlockingIOError:
+            _wait_for_writable(sock, expiration)
+
+
 def send_udp(sock, what, destination, expiration=None):
     """Send a DNS message to the specified UDP socket.
 
@@ -361,9 +388,8 @@ def send_udp(sock, what, destination, expiration=None):
 
     if isinstance(what, dns.message.Message):
         what = what.to_wire()
-    _wait_for_writable(sock, expiration)
     sent_time = time.time()
-    n = sock.sendto(what, destination)
+    n = _udp_send(sock, what, destination, expiration)
     return (n, sent_time)
 
 
@@ -413,9 +439,8 @@ def receive_udp(sock, destination=None, expiration=None,
     """
 
     wire = b''
-    while 1:
-        _wait_for_readable(sock, expiration)
-        (wire, from_address) = sock.recvfrom(65535)
+    while True:
+        (wire, from_address) = _udp_recv(sock, 65535, expiration)
         if _matches_destination(sock.family, from_address, destination,
                                 ignore_unexpected):
             break
@@ -553,18 +578,16 @@ def _net_read(sock, count, expiration):
     """
     s = b''
     while count > 0:
-        _wait_for_readable(sock, expiration)
         try:
             n = sock.recv(count)
-        except ssl.SSLWantReadError:  # pragma: no cover
-            continue
+            if n == b'':
+                raise EOFError
+            count -= len(n)
+            s += n
+        except (BlockingIOError, ssl.SSLWantReadError):
+            _wait_for_readable(sock, expiration)
         except ssl.SSLWantWriteError:  # pragma: no cover
             _wait_for_writable(sock, expiration)
-            continue
-        if n == b'':
-            raise EOFError
-        count = count - len(n)
-        s = s + n
     return s
 
 
@@ -576,14 +599,12 @@ def _net_write(sock, data, expiration):
     current = 0
     l = len(data)
     while current < l:
-        _wait_for_writable(sock, expiration)
         try:
             current += sock.send(data[current:])
+        except (BlockingIOError, ssl.SSLWantWriteError):
+            _wait_for_writable(sock, expiration)
         except ssl.SSLWantReadError:  # pragma: no cover
             _wait_for_readable(sock, expiration)
-            continue
-        except ssl.SSLWantWriteError:  # pragma: no cover
-            continue
 
 
 def send_tcp(sock, what, expiration=None):
@@ -607,7 +628,6 @@ def send_tcp(sock, what, expiration=None):
     # avoid writev() or doing a short write that would get pushed
     # onto the net
     tcpmsg = struct.pack("!H", l) + what
-    _wait_for_writable(sock, expiration)
     sent_time = time.time()
     _net_write(sock, tcpmsg, expiration)
     return (len(tcpmsg), sent_time)
@@ -697,11 +717,6 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
     (begin_time, expiration) = _compute_times(timeout)
     with contextlib.ExitStack() as stack:
         if sock:
-            #
-            # Verify that the socket is connected, as if it's not connected,
-            # it's not writable, and the polling in send_tcp() will time out or
-            # hang forever.
-            sock.getpeername()
             s = sock
         else:
             (af, destination, source) = _destination_and_source(where, port,
@@ -881,8 +896,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
         _connect(s, destination, expiration)
         l = len(wire)
         if use_udp:
-            _wait_for_writable(s, expiration)
-            s.send(wire)
+            _udp_send(s, wire, None, expiration)
         else:
             tcpmsg = struct.pack("!H", l) + wire
             _net_write(s, tcpmsg, expiration)
@@ -903,8 +917,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
                (expiration is not None and mexpiration > expiration):
                 mexpiration = expiration
             if use_udp:
-                _wait_for_readable(s, expiration)
-                (wire, from_address) = s.recvfrom(65535)
+                (wire, from_address) = _udp_recv(s, 65535, expiration)
             else:
                 ldata = _net_read(s, 2, mexpiration)
                 (l,) = struct.unpack("!H", ldata)