From: Bob Halley Date: Tue, 16 Jun 2020 14:57:59 +0000 (-0700) Subject: Overhaul _destination_and_source. X-Git-Tag: v2.0.0rc1~96 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9833924a4d4c58e35242f370539f7e6e1e61a4b9;p=thirdparty%2Fdnspython.git Overhaul _destination_and_source. We now use dns.inet.low_level_address_tuple() for the low-level tuple conversion. We now detect mismatches between source and destination address families. If a source_port has been specified but we have no idea about the family, complain. (This can only happen when 'where' is a URL and no source address has been specified either.) 'where' MUST be an address literal unless being called by DoH code, but we tolerated failures in other cases. In the DoH case where 'where' was a URL and source was specified, the lack of an address family in the destination caused us to return None for the source, and thus not set it even though the caller asked for it. We now infer the address family from the source address in that case. --- diff --git a/dns/query.py b/dns/query.py index 17f1bae8..f224fe3d 100644 --- a/dns/query.py +++ b/dns/query.py @@ -201,34 +201,47 @@ def _addresses_equal(af, a1, a2): return n1 == n2 and a1[1:] == a2[1:] -def _destination_and_source(af, where, port, source, source_port, - default_to_inet=True): +def _destination_and_source(where, port, source, source_port, + where_must_be_address=True): # Apply defaults and compute destination and source tuples # suitable for use in connect(), sendto(), or bind(). - if af is None: - try: - af = dns.inet.af_for_address(where) - except Exception: - if default_to_inet: - af = dns.inet.AF_INET - if af == dns.inet.AF_INET: - destination = (where, port) - if source is not None or source_port != 0: - if source is None: - source = '0.0.0.0' - source = (source, source_port) - elif af == dns.inet.AF_INET6: - ai_flags = socket.AI_NUMERICHOST - ((*_, destination), *_) = socket.getaddrinfo(where, port, - flags=ai_flags) - if source is not None or source_port != 0: - if source is None: - source = '::' - ((*_, source), *_) = socket.getaddrinfo(source, source_port, - flags=ai_flags) - else: - source = None - destination = None + af = None + destination = None + try: + af = dns.inet.af_for_address(where) + destination = where + except Exception: + if where_must_be_address: + raise + # URLs are ok so eat the exception + if source: + saf = dns.inet.af_for_address(source) + if af: + # We know the destination af, so source had better agree! + if saf != af: + raise ValueError('different address families for source ' + + 'and destination') + else: + # We didn't know the destination af, but we know the source, + # so that's our af. + af = saf + if source_port and not source: + # Caller has specified a source_port but not an address, so we + # need to return a source, and we need to use the appropriate + # wildcard address as the address. + if af == dns.inet.AF_INET: + source = '0.0.0.0' + elif af == dns.inet.AF_INET6: + source = '::' + else: + raise ValueError('source_port specified but address family is ' + 'unknown') + # Convert high-level (address, port) tuples into low-level address + # tuples. + if destination: + destination = dns.inet.low_level_address_tuple((destination, port), af) + if source: + source = dns.inet.low_level_address_tuple((source, source_port), af) return (af, destination, source) def _make_socket(af, type, source, ssl_context=None, server_hostname=None): @@ -295,7 +308,7 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, raise NoDOH wire = q.to_wire() - (af, destination, source) = _destination_and_source(None, where, port, + (af, destination, source) = _destination_and_source(where, port, source, source_port, False) transport_adapter = None @@ -480,7 +493,7 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0, """ wire = q.to_wire() - (af, destination, source) = _destination_and_source(None, where, port, + (af, destination, source) = _destination_and_source(where, port, source, source_port) (begin_time, expiration) = _compute_times(timeout) with contextlib.ExitStack() as stack: @@ -712,8 +725,8 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0, sock.getpeername() s = sock else: - (af, destination, source) = _destination_and_source(None, where, - port, source, + (af, destination, source) = _destination_and_source(where, port, + source, source_port) s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM, source)) @@ -792,7 +805,7 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0, wire = q.to_wire() (begin_time, expiration) = _compute_times(timeout) - (af, destination, source) = _destination_and_source(None, where, port, + (af, destination, source) = _destination_and_source(where, port, source, source_port) if ssl_context is None and not sock: ssl_context = ssl.create_default_context() @@ -879,7 +892,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, if keyring is not None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) wire = q.to_wire() - (af, destination, source) = _destination_and_source(None, where, port, + (af, destination, source) = _destination_and_source(where, port, source, source_port) if use_udp and rdtype != dns.rdatatype.IXFR: raise ValueError('cannot do a UDP AXFR')