]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
add dns.resolver.zone_from_name() and .get_default_resovler()
authorBob Halley <halley@dnspython.org>
Sat, 1 Oct 2005 03:18:11 +0000 (03:18 +0000)
committerBob Halley <halley@dnspython.org>
Sat, 1 Oct 2005 03:18:11 +0000 (03:18 +0000)
ChangeLog
dns/resolver.py
tests/resolver.py

index 88ec663d4189a1f81976e8e2f836aefb41a10712..226fa76821430ddddf3b6d06fc66a1bc29eaeb4f 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,11 @@
+2005-10-01  Bob Halley  <halley@dnspython.org>
+
+       * dns/resolver.py: Added zone_for_name() helper, which returns
+         the name of the zone which contains the specified name.
+
+       * dns/resolver.py: Added get_default_resolver(), which returns
+         the default resolver, initializing it if necessary.
+
 2005-09-29  Bob Halley  <halley@dnspython.org>
 
        * dns/resolver.py (Resolver._compute_timeout): If time goes
index 7433d4a376dd1e230cfe659e4293502679ee0310..ca47e31f163202393f7b04300aa571c6bff7ff12 100644 (file)
@@ -52,6 +52,17 @@ class NoNameservers(dns.exception.DNSException):
     """No non-broken nameservers are available to answer the query."""
     pass
 
+class NotAbsolute(dns.exception.DNSException):
+    """Raised if an absolute domain name is required but a relative name
+    was provided."""
+    pass
+
+class NoRootSOA(dns.exception.DNSException):
+    """Raised if for some reason there is no SOA at the root name.
+    This should never happen!"""
+    pass
+
+
 class Answer(object):
     """DNS stub resolver answer
 
@@ -602,6 +613,13 @@ class Resolver(object):
 
 default_resolver = None
 
+def get_default_resolver():
+    """Get the default resolver, initializing it if necessary."""
+    global default_resolver
+    if default_resolver is None:
+        default_resolver = Resolver()
+    return default_resolver
+
 def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
           tcp=False):
     """Query nameservers to find the answer to the question.
@@ -610,7 +628,31 @@ def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
     object to make the query.
     @see: L{dns.resolver.Resolver.query} for more information on the
     parameters."""
-    global default_resolver
-    if default_resolver is None:
-        default_resolver = Resolver()
-    return default_resolver.query(qname, rdtype, rdclass, tcp)
+    return get_default_resolver().query(qname, rdtype, rdclass, tcp)
+
+def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None):
+    """Find the name of the zone which contains the specified name.
+
+    @param name: the query name
+    @type name: absolute dns.name.Name object or string
+    @ivar rdclass: The query class
+    @type rdclass: int
+    @param tcp: use TCP to make the query (default is False).
+    @type tcp: bool
+    @param resolver: the resolver to use
+    @type resolver: dns.resolver.Resolver object or None
+    @rtype: dns.name.Name"""
+
+    if isinstance(name, str):
+        name = dns.name.from_text(name, dns.name.root)
+    if resolver is None:
+        resolver = get_default_resolver()
+    if not name.is_absolute():
+        raise NotAbsolute, name
+    while len(name) > 0:
+        try:
+            answer = resolver.query(name, dns.rdatatype.SOA, rdclass, tcp)
+            return name
+        except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
+            name = dns.name.Name(name[1:])
+    raise NoRootSoa
index 0b17d0b225cfa4f779959ba73c69626293fa7155..cd644ec66a75a3f76a49e84c796bc2ca6b085680 100644 (file)
@@ -77,5 +77,29 @@ class ResolverTestCase(unittest.TestCase):
         self.failUnless(cache.get((name, dns.rdatatype.A, dns.rdataclass.IN))
                         is None)
 
+    def testZoneForName1(self):
+        name = dns.name.from_text('www.dnspython.org.')
+        ezname = dns.name.from_text('dnspython.org.')
+        zname = dns.resolver.zone_for_name(name)
+        self.failUnless(zname == ezname)
+
+    def testZoneForName2(self):
+        name = dns.name.from_text('a.b.www.dnspython.org.')
+        ezname = dns.name.from_text('dnspython.org.')
+        zname = dns.resolver.zone_for_name(name)
+        self.failUnless(zname == ezname)
+
+    def testZoneForName3(self):
+        name = dns.name.from_text('dnspython.org.')
+        ezname = dns.name.from_text('dnspython.org.')
+        zname = dns.resolver.zone_for_name(name)
+        self.failUnless(zname == ezname)
+
+    def testZoneForName4(self):
+        def bad():
+            name = dns.name.from_text('dnspython.org', None)
+            zname = dns.resolver.zone_for_name(name)
+        self.failUnlessRaises(dns.resolver.NotAbsolute, bad)
+
 if __name__ == '__main__':
     unittest.main()