]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
add dns.resolver.override_system_resolver() and dns.resolver.restore_system_resolver()
authorBob Halley <halley@nominum.com>
Wed, 13 Jul 2011 19:55:30 +0000 (12:55 -0700)
committerBob Halley <halley@nominum.com>
Wed, 13 Jul 2011 19:55:30 +0000 (12:55 -0700)
ChangeLog
dns/resolver.py

index 08d93d83c398fb2db3a0bb756932f461f639cc2c..df20c8734012360c29cccda5ecaed586fd9bec5b 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,15 @@
+2011-07-13  Bob Halley  <halley@dnspython.org>
+
+       * dns/resolver.py: dns.resolver.override_system_resolver()
+         overrides the socket module's versions of getaddrinfo(),
+         getnameinfo(), getfqdn(), gethostbyname(), gethostbyname_ex() and
+         gethostbyaddr() with an implementation which uses a dnspython stub
+         resolver instead of the system's stub resolver.  This can be
+         useful in testing situations where you want to control the
+         resolution behavior of python code without having to change the
+         system's resolver settings (e.g. /etc/resolv.conf).
+         dns.resolver.restore_system_resolver() undoes the change.
+
 2011-07-08  Bob Halley  <halley@dnspython.org>
 
        * dns/ipv4.py: dnspython now provides its own, stricter, versions
index 7c199be4990ee631d702617ea6b03e9488599519..34ea45ef499ac394ad1363deb5bd449414335e46 100644 (file)
@@ -23,12 +23,15 @@ import sys
 import time
 
 import dns.exception
+import dns.ipv4
+import dns.ipv6
 import dns.message
 import dns.name
 import dns.query
 import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
+import dns.reversename
 
 if sys.platform == 'win32':
     import _winreg
@@ -794,3 +797,235 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None):
             name = name.parent()
         except dns.name.NoParent:
             raise NoRootSOA
+
+#
+# Support for overriding the system resolver for all python code in the
+# running process.
+#
+
+_protocols_for_socktype = {
+    socket.SOCK_DGRAM : [socket.SOL_UDP],
+    socket.SOCK_STREAM : [socket.SOL_TCP],
+    }
+
+_resolver = None
+_original_getaddrinfo = socket.getaddrinfo
+_original_getnameinfo = socket.getnameinfo
+_original_getfqdn = socket.getfqdn
+_original_gethostbyname = socket.gethostbyname
+_original_gethostbyname_ex = socket.gethostbyname_ex
+_original_gethostbyaddr = socket.gethostbyaddr
+
+def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0,
+                 proto=0, flags=0):
+    if flags & (socket.AI_ADDRCONFIG|socket.AI_V4MAPPED) != 0:
+        raise NotImplementedError
+    if host is None and service is None:
+        raise socket.gaierror(socket.EAI_NONAME)
+    v6addrs = []
+    v4addrs = []
+    canonical_name = None
+    try:
+        # Is host None or a V6 address literal?
+        if host is None:
+            canonical_name = 'localhost'
+            if flags & socket.AI_PASSIVE != 0:
+                v6addrs.append('::')
+                v4addrs.append('0.0.0.0')
+            else:
+                v6addrs.append('::1')
+                v4addrs.append('127.0.0.1')
+        else:
+            parts = host.split('%')
+            if len(parts) == 2:
+                ahost = parts[0]
+            else:
+                ahost = host
+            addr = dns.ipv6.inet_aton(ahost)
+            v6addrs.append(host)
+            canonical_name = host
+    except:
+        try:
+            # Is it a V4 address literal?
+            addr = dns.ipv4.inet_aton(host)
+            v4addrs.append(host)
+            canonical_name = host
+        except:
+            if flags & socket.AI_NUMERICHOST == 0:
+                try:
+                    qname = None
+                    if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
+                        v6 = _resolver.query(host, dns.rdatatype.AAAA,
+                                             raise_on_no_answer=False)
+                        # Note that setting host ensures we query the same name
+                        # for A as we did for AAAA.
+                        host = v6.qname
+                        canonical_name = v6.canonical_name.to_text(True)
+                        if v6.rrset is not None:
+                            for rdata in v6.rrset:
+                                v6addrs.append(rdata.address)
+                    if family == socket.AF_INET or family == socket.AF_UNSPEC:
+                        v4 = _resolver.query(host, dns.rdatatype.A,
+                                             raise_on_no_answer=False)
+                        host = v4.qname
+                        canonical_name = v4.canonical_name.to_text(True)
+                        if v4.rrset is not None:
+                            for rdata in v4.rrset:
+                                v4addrs.append(rdata.address)
+                except dns.resolver.NXDOMAIN:
+                    raise socket.gaierror(socket.EAI_NONAME)
+                except:
+                    raise socket.gaierror(socket.EAI_SYSTEM)
+    port = None
+    try:
+        # Is it a port literal?
+        if service is None:
+            port = 0
+        else:
+            port = int(service)
+    except:
+        if flags & socket.AI_NUMERICSERV == 0:
+            try:
+                port = socket.getservbyname(service)
+            except:
+                pass
+    if port is None:
+        raise socket.gaierror(socket.EAI_NONAME)
+    tuples = []
+    if socktype == 0:
+        socktypes = [socket.SOCK_DGRAM, socket.SOCK_STREAM]
+    else:
+        socktypes = [socktype]
+    if flags & socket.AI_CANONNAME != 0:
+        cname = canonical_name
+    else:
+        cname = ''
+    if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
+        for addr in v6addrs:
+            for socktype in socktypes:
+                for proto in _protocols_for_socktype[socktype]:
+                    tuples.append((socket.AF_INET6, socktype, proto,
+                                   cname, (addr, port, 0, 0)))
+    if family == socket.AF_INET or family == socket.AF_UNSPEC:
+        for addr in v4addrs:
+            for socktype in socktypes:
+                for proto in _protocols_for_socktype[socktype]:
+                    tuples.append((socket.AF_INET, socktype, proto,
+                                   cname, (addr, port)))
+    if len(tuples) == 0:
+        raise socket.gaierror(socket.EAI_NONAME)
+    return tuples
+
+def _getnameinfo(sockaddr, flags=0):
+    host = sockaddr[0]
+    port = sockaddr[1]
+    if len(sockaddr) == 4:
+        scope = sockaddr[3]
+        family = socket.AF_INET6
+    else:
+        scope = None
+        family = socket.AF_INET
+    tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM,
+                          socket.SOL_TCP, 0)
+    if len(tuples) > 1:
+        raise socket.error('sockaddr resolved to multiple addresses')
+    addr = tuples[0][4][0]
+    if flags & socket.NI_DGRAM:
+        pname = 'udp'
+    else:
+        pname = 'tcp'
+    qname = dns.reversename.from_address(addr)
+    if flags & socket.NI_NUMERICHOST == 0:
+        try:
+            answer = _resolver.query(qname, 'PTR')
+            hostname = answer.rrset[0].target.to_text(True)
+        except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
+            if flags & socket.NI_NAMEREQD:
+                raise socket.gaierror(socket.EAI_NONAME)
+            hostname = addr
+            if scope is not None:
+                hostname += '%' + str(scope)
+    else:
+        hostname = addr
+        if scope is not None:
+            hostname += '%' + str(scope)
+    if flags & socket.NI_NUMERICSERV:
+        service = str(port)
+    else:
+        service = socket.getservbyport(port, pname)
+    return (hostname, service)
+
+def _getfqdn(name=None):
+    if name is None:
+        name = socket.gethostname()
+    return _getnameinfo(_getaddrinfo(name, 80)[0][4])[0]
+
+def _gethostbyname(name):
+    return _gethostbyname_ex(name)[2][0]
+
+def _gethostbyname_ex(name):
+    aliases = []
+    addresses = []
+    tuples = _getaddrinfo(name, 0, socket.AF_INET, socket.SOCK_STREAM,
+                          socket.SOL_TCP, socket.AI_CANONNAME)
+    canonical = tuples[0][3]
+    for item in tuples:
+        addresses.append(item[4][0])
+    # XXX we just ignore aliases
+    return (canonical, aliases, addresses)
+
+def _gethostbyaddr(ip):
+    try:
+        addr = dns.ipv6.inet_aton(ip)
+        sockaddr = (ip, 80, 0, 0)
+        family = socket.AF_INET6
+    except:
+        sockaddr = (ip, 80)
+        family = socket.AF_INET
+    (name, port) = _getnameinfo(sockaddr, socket.NI_NAMEREQD)
+    aliases = []
+    addresses = []
+    tuples = _getaddrinfo(name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP,
+                          socket.AI_CANONNAME)
+    canonical = tuples[0][3]
+    for item in tuples:
+        addresses.append(item[4][0])
+    # XXX we just ignore aliases
+    return (canonical, aliases, addresses)
+
+def override_system_resolver(resolver=None):
+    """Override the system resolver routines in the socket module with
+    versions which use dnspython's resolver.
+
+    This can be useful in testing situations where you want to control
+    the resolution behavior of python code without having to change
+    the system's resolver settings (e.g. /etc/resolv.conf).
+
+    The resolver to use may be specified; if it's not, the default
+    resolver will be used.
+
+    @param resolver: the resolver to use
+    @type resolver: dns.resolver.Resolver object or None
+    """
+    if resolver is None:
+        resolver = get_default_resolver()
+    global _resolver
+    _resolver = resolver
+    socket.getaddrinfo = _getaddrinfo
+    socket.getnameinfo = _getnameinfo
+    socket.getfqdn = _getfqdn
+    socket.gethostbyname = _gethostbyname
+    socket.gethostbyname_ex = _gethostbyname_ex
+    socket.gethostbyaddr = _gethostbyaddr
+
+def restore_system_resolver():
+    """Undo the effects of override_system_resolver().
+    """
+    global _resolver
+    _resolver = None
+    socket.getaddrinfo = _original_getaddrinfo
+    socket.getnameinfo = _original_getnameinfo
+    socket.getfqdn = _original_getfqdn
+    socket.gethostbyname = _original_gethostbyname
+    socket.gethostbyname_ex = _original_gethostbyname_ex
+    socket.gethostbyaddr = _original_gethostbyaddr