]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Overhaul _destination_and_source.
authorBob Halley <halley@dnspython.org>
Tue, 16 Jun 2020 14:57:59 +0000 (07:57 -0700)
committerBob Halley <halley@dnspython.org>
Tue, 16 Jun 2020 14:57:59 +0000 (07:57 -0700)
We now use dns.inet.low_level_address_tuple() for the low-level tuple
conversion.

We now detect mismatches between source and destination address families.

If a source_port has been specified but we have no idea about the
family, complain.  (This can only happen when 'where' is a URL and no
source address has been specified either.)

'where' MUST be an address literal unless being called by DoH code, but we
tolerated failures in other cases.

In the DoH case where 'where' was a URL and source was specified, the
lack of an address family in the destination caused us to return None
for the source, and thus not set it even though the caller asked for it.
We now infer the address family from the source address in that case.

dns/query.py

index 17f1bae8b85615a1699fc62405ee8b613a469e6d..f224fe3db92f4ccc0f5f64e7916a78bc04d62cd8 100644 (file)
@@ -201,34 +201,47 @@ def _addresses_equal(af, a1, a2):
     return n1 == n2 and a1[1:] == a2[1:]
 
 
-def _destination_and_source(af, where, port, source, source_port,
-                            default_to_inet=True):
+def _destination_and_source(where, port, source, source_port,
+                            where_must_be_address=True):
     # Apply defaults and compute destination and source tuples
     # suitable for use in connect(), sendto(), or bind().
-    if af is None:
-        try:
-            af = dns.inet.af_for_address(where)
-        except Exception:
-            if default_to_inet:
-                af = dns.inet.AF_INET
-    if af == dns.inet.AF_INET:
-        destination = (where, port)
-        if source is not None or source_port != 0:
-            if source is None:
-                source = '0.0.0.0'
-            source = (source, source_port)
-    elif af == dns.inet.AF_INET6:
-        ai_flags = socket.AI_NUMERICHOST
-        ((*_, destination), *_) = socket.getaddrinfo(where, port,
-                                                     flags=ai_flags)
-        if source is not None or source_port != 0:
-            if source is None:
-                source = '::'
-            ((*_, source), *_) = socket.getaddrinfo(source, source_port,
-                                                    flags=ai_flags)
-    else:
-        source = None
-        destination = None
+    af = None
+    destination = None
+    try:
+        af = dns.inet.af_for_address(where)
+        destination = where
+    except Exception:
+        if where_must_be_address:
+            raise
+        # URLs are ok so eat the exception
+    if source:
+        saf = dns.inet.af_for_address(source)
+        if af:
+            # We know the destination af, so source had better agree!
+            if saf != af:
+                raise ValueError('different address families for source ' +
+                                 'and destination')
+        else:
+            # We didn't know the destination af, but we know the source,
+            # so that's our af.
+            af = saf
+    if source_port and not source:
+        # Caller has specified a source_port but not an address, so we
+        # need to return a source, and we need to use the appropriate
+        # wildcard address as the address.
+        if af == dns.inet.AF_INET:
+            source = '0.0.0.0'
+        elif af == dns.inet.AF_INET6:
+            source = '::'
+        else:
+            raise ValueError('source_port specified but address family is '
+                             'unknown')
+    # Convert high-level (address, port) tuples into low-level address
+    # tuples.
+    if destination:
+        destination = dns.inet.low_level_address_tuple((destination, port), af)
+    if source:
+        source = dns.inet.low_level_address_tuple((source, source_port), af)
     return (af, destination, source)
 
 def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
@@ -295,7 +308,7 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
         raise NoDOH
 
     wire = q.to_wire()
-    (af, destination, source) = _destination_and_source(None, where, port,
+    (af, destination, source) = _destination_and_source(where, port,
                                                         source, source_port,
                                                         False)
     transport_adapter = None
@@ -480,7 +493,7 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0,
     """
 
     wire = q.to_wire()
-    (af, destination, source) = _destination_and_source(None, where, port,
+    (af, destination, source) = _destination_and_source(where, port,
                                                         source, source_port)
     (begin_time, expiration) = _compute_times(timeout)
     with contextlib.ExitStack() as stack:
@@ -712,8 +725,8 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
             sock.getpeername()
             s = sock
         else:
-            (af, destination, source) = _destination_and_source(None, where,
-                                                                port, source,
+            (af, destination, source) = _destination_and_source(where, port,
+                                                                source,
                                                                 source_port)
             s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM,
                                                  source))
@@ -792,7 +805,7 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0,
 
     wire = q.to_wire()
     (begin_time, expiration) = _compute_times(timeout)
-    (af, destination, source) = _destination_and_source(None, where, port,
+    (af, destination, source) = _destination_and_source(where, port,
                                                         source, source_port)
     if ssl_context is None and not sock:
         ssl_context = ssl.create_default_context()
@@ -879,7 +892,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
     if keyring is not None:
         q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
     wire = q.to_wire()
-    (af, destination, source) = _destination_and_source(None, where, port,
+    (af, destination, source) = _destination_and_source(where, port,
                                                         source, source_port)
     if use_udp and rdtype != dns.rdatatype.IXFR:
         raise ValueError('cannot do a UDP AXFR')