]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Grealy simplify our getaddrinfo() implementation by calling the
authorBob Halley <halley@dnspython.org>
Mon, 4 May 2020 15:00:33 +0000 (08:00 -0700)
committerBob Halley <halley@dnspython.org>
Mon, 4 May 2020 15:00:33 +0000 (08:00 -0700)
system's version when we have an address literal for the host.  This
also avoids infinite loops as dns.query.* needs to call getaddrinfo()
to handle scoping correctly.

dns/resolver.py

index bc865633df4c45d77b4db06b6d9c9e2aebfc4b16..e6f145b74536ba130d43a49bc5614a56041a5ec7 100644 (file)
@@ -1222,6 +1222,14 @@ _original_gethostbyaddr = socket.gethostbyaddr
 
 def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0,
                  proto=0, flags=0):
+    if flags & socket.AI_NUMERICHOST != 0:
+        # Short circuit directly into the system's getaddrinfo().  We're
+        # not adding any value in this case, and this avoids infinite loops
+        # because dns.query.* needs to call getaddrinfo() for IPv6 scoping
+        # reasons.  We will also do this short circuit below if we
+        # discover that the host is an address literal.
+        return _original_getaddrinfo(host, service, family, socktype, proto,
+                                     flags)
     if flags & (socket.AI_ADDRCONFIG | socket.AI_V4MAPPED) != 0:
         # Not implemented.  We raise a gaierror as opposed to a
         # NotImplementedError as it helps callers handle errors more
@@ -1232,56 +1240,41 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0,
     v6addrs = []
     v4addrs = []
     canonical_name = None
+    # Is host None or an address literal?  If so, use the system's
+    # getaddrinfo().
+    if host is None:
+        return _original_getaddrinfo(host, service, family, socktype,
+                                     proto, flags)
     try:
-        # Is host None or a V6 address literal?
-        if host is None:
-            canonical_name = 'localhost'
-            if flags & socket.AI_PASSIVE != 0:
-                v6addrs.append('::')
-                v4addrs.append('0.0.0.0')
-            else:
-                v6addrs.append('::1')
-                v4addrs.append('127.0.0.1')
-        else:
-            parts = host.split('%')
-            if len(parts) == 2:
-                ahost = parts[0]
-            else:
-                ahost = host
-            addr = dns.ipv6.inet_aton(ahost)
-            v6addrs.append(host)
-            canonical_name = host
+        af = dns.inet.af_for_address(host)
+        return _original_getaddrinfo(host, service, family, socktype,
+                                     proto, flags)
     except Exception:
-        try:
-            # Is it a V4 address literal?
-            addr = dns.ipv4.inet_aton(host)
-            v4addrs.append(host)
-            canonical_name = host
-        except Exception:
-            if flags & socket.AI_NUMERICHOST == 0:
-                try:
-                    if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
-                        v6 = _resolver.query(host, dns.rdatatype.AAAA,
-                                             raise_on_no_answer=False)
-                        # Note that setting host ensures we query the same name
-                        # for A as we did for AAAA.
-                        host = v6.qname
-                        canonical_name = v6.canonical_name.to_text(True)
-                        if v6.rrset is not None:
-                            for rdata in v6.rrset:
-                                v6addrs.append(rdata.address)
-                    if family == socket.AF_INET or family == socket.AF_UNSPEC:
-                        v4 = _resolver.query(host, dns.rdatatype.A,
-                                             raise_on_no_answer=False)
-                        host = v4.qname
-                        canonical_name = v4.canonical_name.to_text(True)
-                        if v4.rrset is not None:
-                            for rdata in v4.rrset:
-                                v4addrs.append(rdata.address)
-                except dns.resolver.NXDOMAIN:
-                    raise socket.gaierror(socket.EAI_NONAME)
-                except Exception:
-                    raise socket.gaierror(socket.EAI_SYSTEM)
+        pass
+    # Something needs resolution!
+    try:
+        if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
+            v6 = _resolver.query(host, dns.rdatatype.AAAA,
+                                 raise_on_no_answer=False)
+            # Note that setting host ensures we query the same name
+            # for A as we did for AAAA.
+            host = v6.qname
+            canonical_name = v6.canonical_name.to_text(True)
+            if v6.rrset is not None:
+                for rdata in v6.rrset:
+                    v6addrs.append(rdata.address)
+        if family == socket.AF_INET or family == socket.AF_UNSPEC:
+            v4 = _resolver.query(host, dns.rdatatype.A,
+                                 raise_on_no_answer=False)
+            host = v4.qname
+            canonical_name = v4.canonical_name.to_text(True)
+            if v4.rrset is not None:
+                for rdata in v4.rrset:
+                    v4addrs.append(rdata.address)
+    except dns.resolver.NXDOMAIN:
+        raise socket.gaierror(socket.EAI_NONAME)
+    except Exception:
+        raise socket.gaierror(socket.EAI_SYSTEM)
     port = None
     try:
         # Is it a port literal?