]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Improve TCP connect behavior. 389/head
authorBrian Wellington <bwelling@xbill.org>
Fri, 30 Aug 2019 18:42:12 +0000 (11:42 -0700)
committerBrian Wellington <bwelling@xbill.org>
Fri, 30 Aug 2019 18:42:12 +0000 (11:42 -0700)
Before this change, the _connect() method would start the connection
process, but not wait for it to complete.  This would leave the socket
in an indeterminate state until some other code checked for writability,
and would lose the error code if the connect failed.

This changes _connect() to wait for the connection to complete, and
raises and exception with the appropriate error code if it fails.

dns/query.py

index 20c953cb328c78db3764539cee83c39807e84567..fd8988ec81d10f811d771091d3a03bf99d78c4d2 100644 (file)
@@ -20,6 +20,7 @@
 from __future__ import generators
 
 import errno
+import os
 import select
 import socket
 import struct
@@ -419,7 +420,7 @@ def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
                               ignore_trailing=ignore_trailing)
     return (r, received_time)
 
-def _connect(s, address):
+def _connect(s, address, expiration):
     try:
         s.connect(address)
     except socket.error:
@@ -431,6 +432,10 @@ def _connect(s, address):
             v_err = v[0]
         if v_err not in [errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY]:
             raise v
+        _wait_for_writable(s, expiration)
+        err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+        if err != 0:
+            raise OSError(err, os.strerror(err)) from None
 
 
 def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
@@ -479,7 +484,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
         begin_time = time.time()
         if source is not None:
             s.bind(source)
-        _connect(s, destination)
+        _connect(s, destination, expiration)
         send_tcp(s, wire, expiration)
         (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
                                          q.keyring, q.mac, ignore_trailing)
@@ -580,7 +585,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
     if source is not None:
         s.bind(source)
     expiration = _compute_expiration(lifetime)
-    _connect(s, destination)
+    _connect(s, destination, expiration)
     l = len(wire)
     if use_udp:
         _wait_for_writable(s, expiration)