From: Bob Halley Date: Sat, 1 Oct 2005 03:18:11 +0000 (+0000) Subject: add dns.resolver.zone_from_name() and .get_default_resovler() X-Git-Tag: v1.3.5~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2cf8da1e004f17dca34c6f4eee12a59ab55b37d9;p=thirdparty%2Fdnspython.git add dns.resolver.zone_from_name() and .get_default_resovler() --- diff --git a/ChangeLog b/ChangeLog index 88ec663d..226fa768 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,11 @@ +2005-10-01 Bob Halley + + * dns/resolver.py: Added zone_for_name() helper, which returns + the name of the zone which contains the specified name. + + * dns/resolver.py: Added get_default_resolver(), which returns + the default resolver, initializing it if necessary. + 2005-09-29 Bob Halley * dns/resolver.py (Resolver._compute_timeout): If time goes diff --git a/dns/resolver.py b/dns/resolver.py index 7433d4a3..ca47e31f 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -52,6 +52,17 @@ class NoNameservers(dns.exception.DNSException): """No non-broken nameservers are available to answer the query.""" pass +class NotAbsolute(dns.exception.DNSException): + """Raised if an absolute domain name is required but a relative name + was provided.""" + pass + +class NoRootSOA(dns.exception.DNSException): + """Raised if for some reason there is no SOA at the root name. + This should never happen!""" + pass + + class Answer(object): """DNS stub resolver answer @@ -602,6 +613,13 @@ class Resolver(object): default_resolver = None +def get_default_resolver(): + """Get the default resolver, initializing it if necessary.""" + global default_resolver + if default_resolver is None: + default_resolver = Resolver() + return default_resolver + def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, tcp=False): """Query nameservers to find the answer to the question. @@ -610,7 +628,31 @@ def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, object to make the query. @see: L{dns.resolver.Resolver.query} for more information on the parameters.""" - global default_resolver - if default_resolver is None: - default_resolver = Resolver() - return default_resolver.query(qname, rdtype, rdclass, tcp) + return get_default_resolver().query(qname, rdtype, rdclass, tcp) + +def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None): + """Find the name of the zone which contains the specified name. + + @param name: the query name + @type name: absolute dns.name.Name object or string + @ivar rdclass: The query class + @type rdclass: int + @param tcp: use TCP to make the query (default is False). + @type tcp: bool + @param resolver: the resolver to use + @type resolver: dns.resolver.Resolver object or None + @rtype: dns.name.Name""" + + if isinstance(name, str): + name = dns.name.from_text(name, dns.name.root) + if resolver is None: + resolver = get_default_resolver() + if not name.is_absolute(): + raise NotAbsolute, name + while len(name) > 0: + try: + answer = resolver.query(name, dns.rdatatype.SOA, rdclass, tcp) + return name + except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): + name = dns.name.Name(name[1:]) + raise NoRootSoa diff --git a/tests/resolver.py b/tests/resolver.py index 0b17d0b2..cd644ec6 100644 --- a/tests/resolver.py +++ b/tests/resolver.py @@ -77,5 +77,29 @@ class ResolverTestCase(unittest.TestCase): self.failUnless(cache.get((name, dns.rdatatype.A, dns.rdataclass.IN)) is None) + def testZoneForName1(self): + name = dns.name.from_text('www.dnspython.org.') + ezname = dns.name.from_text('dnspython.org.') + zname = dns.resolver.zone_for_name(name) + self.failUnless(zname == ezname) + + def testZoneForName2(self): + name = dns.name.from_text('a.b.www.dnspython.org.') + ezname = dns.name.from_text('dnspython.org.') + zname = dns.resolver.zone_for_name(name) + self.failUnless(zname == ezname) + + def testZoneForName3(self): + name = dns.name.from_text('dnspython.org.') + ezname = dns.name.from_text('dnspython.org.') + zname = dns.resolver.zone_for_name(name) + self.failUnless(zname == ezname) + + def testZoneForName4(self): + def bad(): + name = dns.name.from_text('dnspython.org', None) + zname = dns.resolver.zone_for_name(name) + self.failUnlessRaises(dns.resolver.NotAbsolute, bad) + if __name__ == '__main__': unittest.main()