From d1d57c87904a0d69a93180054abad52f0b9dca0f Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Fri, 30 Aug 2019 11:42:12 -0700 Subject: [PATCH] Improve TCP connect behavior. Before this change, the _connect() method would start the connection process, but not wait for it to complete. This would leave the socket in an indeterminate state until some other code checked for writability, and would lose the error code if the connect failed. This changes _connect() to wait for the connection to complete, and raises and exception with the appropriate error code if it fails. --- dns/query.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/dns/query.py b/dns/query.py index 20c953cb..fd8988ec 100644 --- a/dns/query.py +++ b/dns/query.py @@ -20,6 +20,7 @@ from __future__ import generators import errno +import os import select import socket import struct @@ -419,7 +420,7 @@ def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, ignore_trailing=ignore_trailing) return (r, received_time) -def _connect(s, address): +def _connect(s, address, expiration): try: s.connect(address) except socket.error: @@ -431,6 +432,10 @@ def _connect(s, address): v_err = v[0] if v_err not in [errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY]: raise v + _wait_for_writable(s, expiration) + err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise OSError(err, os.strerror(err)) from None def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, @@ -479,7 +484,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, begin_time = time.time() if source is not None: s.bind(source) - _connect(s, destination) + _connect(s, destination, expiration) send_tcp(s, wire, expiration) (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing) @@ -580,7 +585,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, if source is not None: s.bind(source) expiration = _compute_expiration(lifetime) - _connect(s, destination) + _connect(s, destination, expiration) l = len(wire) if use_udp: _wait_for_writable(s, expiration) -- 2.47.3