From: Brian Wellington Date: Fri, 30 Aug 2019 18:42:12 +0000 (-0700) Subject: Improve TCP connect behavior. X-Git-Tag: v2.0.0rc1~361^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d1d57c87904a0d69a93180054abad52f0b9dca0f;p=thirdparty%2Fdnspython.git Improve TCP connect behavior. Before this change, the _connect() method would start the connection process, but not wait for it to complete. This would leave the socket in an indeterminate state until some other code checked for writability, and would lose the error code if the connect failed. This changes _connect() to wait for the connection to complete, and raises and exception with the appropriate error code if it fails. --- diff --git a/dns/query.py b/dns/query.py index 20c953cb..fd8988ec 100644 --- a/dns/query.py +++ b/dns/query.py @@ -20,6 +20,7 @@ from __future__ import generators import errno +import os import select import socket import struct @@ -419,7 +420,7 @@ def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, ignore_trailing=ignore_trailing) return (r, received_time) -def _connect(s, address): +def _connect(s, address, expiration): try: s.connect(address) except socket.error: @@ -431,6 +432,10 @@ def _connect(s, address): v_err = v[0] if v_err not in [errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY]: raise v + _wait_for_writable(s, expiration) + err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise OSError(err, os.strerror(err)) from None def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, @@ -479,7 +484,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, begin_time = time.time() if source is not None: s.bind(source) - _connect(s, destination) + _connect(s, destination, expiration) send_tcp(s, wire, expiration) (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing) @@ -580,7 +585,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, if source is not None: s.bind(source) expiration = _compute_expiration(lifetime) - _connect(s, destination) + _connect(s, destination, expiration) l = len(wire) if use_udp: _wait_for_writable(s, expiration)