]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Pass partial answer when raising DNSException, added unit tests 109/head
authorClaudio Luck <claudio.luck@gmail.com>
Wed, 27 May 2015 15:54:13 +0000 (17:54 +0200)
committerClaudio Luck <cluck@ini.uzh.ch>
Mon, 13 Jun 2016 13:51:35 +0000 (15:51 +0200)
dns/e164.py
dns/exception.py
dns/resolver.py
tests/test_resolver.py

index 2cc911cdbb5452249d149b9b53c4c83fbcec236e..8804becfdcd06039365999aeaba40ff1ea9872bb 100644 (file)
@@ -73,12 +73,13 @@ def query(number, domains, resolver=None):
     """
     if resolver is None:
         resolver = dns.resolver.get_default_resolver()
+    e_nx = dns.resolver.NXDOMAIN()
     for domain in domains:
         if isinstance(domain, string_types):
             domain = dns.name.from_text(domain)
         qname = dns.e164.from_e164(number, domain)
         try:
             return resolver.query(qname, 'NAPTR')
-        except dns.resolver.NXDOMAIN:
-            pass
-    raise dns.resolver.NXDOMAIN
+        except dns.resolver.NXDOMAIN as e:
+            e_nx += e
+    raise e_nx
index 62fbe2cb0c660e93da6c237217ffbb51b49e7027..151f58442f5242047680d077875cf300726c21f8 100644 (file)
@@ -45,8 +45,11 @@ class DNSException(Exception):
 
     def __init__(self, *args, **kwargs):
         self._check_params(*args, **kwargs)
-        self._check_kwargs(**kwargs)
-        self.kwargs = kwargs
+        if kwargs:
+            self.kwargs = self._check_kwargs(**kwargs)
+            self.msg = str(self)
+        else:
+            self.kwargs = dict()  # defined but empty for old mode exceptions
         if self.msg is None:
             # doc string is better implicit message than empty string
             self.msg = self.__doc__
@@ -68,6 +71,7 @@ class DNSException(Exception):
             assert set(kwargs.keys()) == self.supp_kwargs, \
                 'following set of keyword args is required: %s' % (
                     self.supp_kwargs)
+        return kwargs
 
     def _fmt_kwargs(self, **kwargs):
         """Format kwargs before printing them.
index bccb430d649d8a6f72d8198a1f3bb984f7a07e21..7d1fa6f86f28664045ce5d76009cc32966dfa38d 100644 (file)
@@ -51,21 +51,61 @@ if sys.platform == 'win32':
 class NXDOMAIN(dns.exception.DNSException):
 
     """The DNS query name does not exist."""
-    supp_kwargs = set(['qname'])
+    supp_kwargs = set(['qnames', 'responses'])
+    fmt = None  # we have our own __str__ implementation
+
+    def _check_kwargs(self, qnames, responses=None):
+        if not isinstance(qnames, (list, tuple, set)):
+            raise AttributeError("qnames must be a list, tuple or set")
+        if len(qnames) == 0:
+            raise AttributeError("qnames must contain at least one element")
+        if responses is None:
+            responses = {}
+        elif not isinstance(responses, dict):
+            raise AttributeError("responses must be a dict(qname=response)")
+        kwargs = dict(qnames=qnames, responses=responses)
+        return kwargs
 
     def __str__(self):
-        if 'qname' not in self.kwargs:
+        if 'qnames' not in self.kwargs:
             return super(NXDOMAIN, self).__str__()
-
-        qname = self.kwargs['qname']
-        msg = self.__doc__[:-1]
-        if isinstance(qname, (list, set)):
-            if len(qname) > 1:
-                msg = 'None of DNS query names exist'
-                qname = list(map(str, qname))
-            else:
-                qname = qname[0]
-        return "%s: %s" % (msg, (str(qname)))
+        qnames = self.kwargs['qnames']
+        if len(qnames) > 1:
+            msg = 'None of DNS query names exist'
+        else:
+            msg = self.__doc__[:-1]
+        qnames = ', '.join(map(str, qnames))
+        return "%s: %s" % (msg, qnames)
+
+    def canonical_name(self):
+        if not 'qnames' in self.kwargs:
+            raise TypeError("parametrized exception required")
+        IN = dns.rdataclass.IN
+        CNAME = dns.rdatatype.CNAME
+        cname = None
+        for qname in self.kwargs['qnames']:
+            response = self.kwargs['responses'][qname]
+            for answer in response.answer:
+                if answer.rdtype != CNAME or answer.rdclass != IN:
+                    continue
+                cname = answer.items[0].target.to_text()
+            if cname is not None:
+                return dns.name.from_text(cname)
+        return self.kwargs['qnames'][0]
+    canonical_name = property(canonical_name, doc=(
+        "Return the unresolved canonical name."))
+
+    def __add__(self, e_nx):
+        """Augment by results from another NXDOMAIN exception."""
+        qnames0 = list(self.kwargs.get('qnames', []))
+        responses0 = dict(self.kwargs.get('responses', {}))
+        responses1 = e_nx.kwargs.get('responses', {})
+        for qname1 in e_nx.kwargs.get('qnames', []):
+            if qname1 not in qnames0:
+                qnames0.append(qname1)
+            if qname1 in responses1:
+                responses0[qname1] = responses1[qname1]
+        return NXDOMAIN(qnames=qnames0, responses=responses0)
 
 
 class YXDOMAIN(dns.exception.DNSException):
@@ -862,6 +902,7 @@ class Resolver(object):
             else:
                 qnames_to_try.append(qname.concatenate(self.domain))
         all_nxdomain = True
+        nxdomain_responses = {}
         start = time.time()
         for qname in qnames_to_try:
             if self.cache:
@@ -988,11 +1029,12 @@ class Resolver(object):
                     backoff *= 2
                     time.sleep(sleep_time)
             if response.rcode() == dns.rcode.NXDOMAIN:
+                nxdomain_responses[qname] = response
                 continue
             all_nxdomain = False
             break
         if all_nxdomain:
-            raise NXDOMAIN(qname=qnames_to_try)
+            raise NXDOMAIN(qnames=qnames_to_try, responses=nxdomain_responses)
         answer = Answer(qname, rdtype, rdclass, response,
                         raise_on_no_answer)
         if self.cache:
index eb4f23ec7b9a15d6cb33fc62068db8e16fac568c..00b4949325dbcafa5a4fd9464c8ff45a733a193f 100644 (file)
@@ -60,6 +60,45 @@ example. 1 IN A 10.0.0.1
 ;ADDITIONAL
 """
 
+dangling_cname_0_message_text = """id 10000
+opcode QUERY
+rcode NOERROR
+flags QR AA RD RA
+;QUESTION
+91.11.17.172.in-addr.arpa.none. IN PTR
+;ANSWER
+;AUTHORITY
+;ADDITIONAL
+"""
+
+dangling_cname_1_message_text = """id 10001
+opcode QUERY
+rcode NOERROR
+flags QR AA RD RA
+;QUESTION
+91.11.17.172.in-addr.arpa. IN PTR
+;ANSWER
+11.17.172.in-addr.arpa. 86400 IN DNAME 11.8-22.17.172.in-addr.arpa.
+91.11.17.172.in-addr.arpa. 86400 IN CNAME 91.11.8-22.17.172.in-addr.arpa.
+;AUTHORITY
+;ADDITIONAL
+"""
+
+dangling_cname_2_message_text = """id 10002
+opcode QUERY
+rcode NOERROR
+flags QR AA RD RA
+;QUESTION
+91.11.17.172.in-addr.arpa.example. IN PTR
+;ANSWER
+91.11.17.172.in-addr.arpa.example. 86400 IN CNAME 91.11.17.172.in-addr.arpa.base.
+91.11.17.172.in-addr.arpa.base. 86400 IN CNAME 91.11.17.172.clients.example.
+91.11.17.172.clients.example. 86400 IN CNAME 91-11-17-172.dynamic.example.
+;AUTHORITY
+;ADDITIONAL
+"""
+
+
 class FakeAnswer(object):
     def __init__(self, expiration):
         self.expiration = expiration
@@ -196,5 +235,144 @@ if hasattr(select, 'poll'):
         def polling_backend(self):
             return dns.query._poll_for
 
+class NXDOMAINExceptionTestCase(unittest.TestCase):
+
+    def test_nxdomain_compatible(self):
+        n1 = dns.name.Name(('a', 'b', ''))
+        n2 = dns.name.Name(('a', 'b', 's', ''))
+        py3 = (sys.version_info[0] > 2)
+
+        try:
+            raise dns.resolver.NXDOMAIN
+        except Exception as e:
+            if not py3: self.assertTrue((e.message == e.__doc__))
+            self.assertTrue((e.args == (e.__doc__,)))
+            self.assertTrue(('kwargs' in dir(e)))
+            self.assertTrue((str(e) == e.__doc__), str(e))
+            self.assertTrue(('qnames' not in e.kwargs))
+            self.assertTrue(('responses' not in e.kwargs))
+
+        try:
+            raise dns.resolver.NXDOMAIN("errmsg")
+        except Exception as e:
+            if not py3: self.assertTrue((e.message == "errmsg"))
+            self.assertTrue((e.args == ("errmsg",)))
+            self.assertTrue(('kwargs' in dir(e)))
+            self.assertTrue((str(e) == "errmsg"), str(e))
+            self.assertTrue(('qnames' not in e.kwargs))
+            self.assertTrue(('responses' not in e.kwargs))
+
+        try:
+            raise dns.resolver.NXDOMAIN("errmsg", -1)
+        except Exception as e:
+            if not py3: self.assertTrue((e.message == ""))
+            self.assertTrue((e.args == ("errmsg", -1)))
+            self.assertTrue(('kwargs' in dir(e)))
+            self.assertTrue((str(e) == "('errmsg', -1)"), str(e))
+            self.assertTrue(('qnames' not in e.kwargs))
+            self.assertTrue(('responses' not in e.kwargs))
+
+        try:
+            raise dns.resolver.NXDOMAIN(qnames=None)
+        except Exception as e:
+            self.assertTrue((isinstance(e, AttributeError)))
+
+        try:
+            raise dns.resolver.NXDOMAIN(qnames=n1)
+        except Exception as e:
+            self.assertTrue((isinstance(e, AttributeError)))
+
+        try:
+            raise dns.resolver.NXDOMAIN(qnames=[])
+        except Exception as e:
+            self.assertTrue((isinstance(e, AttributeError)))
+
+        try:
+            raise dns.resolver.NXDOMAIN(qnames=[n1])
+        except Exception as e:
+            MSG = "The DNS query name does not exist: a.b."
+            if not py3: self.assertTrue((e.message == MSG), e.message)
+            self.assertTrue((e.args == (MSG,)), repr(e.args))
+            self.assertTrue(('kwargs' in dir(e)))
+            self.assertTrue((str(e) == MSG), str(e))
+            self.assertTrue(('qnames' in e.kwargs))
+            self.assertTrue((e.kwargs['qnames'] == [n1]))
+            self.assertTrue(('responses' in e.kwargs))
+            self.assertTrue((e.kwargs['responses'] == {}))
+
+        try:
+            raise dns.resolver.NXDOMAIN(qnames=[n2, n1])
+        except Exception as e:
+            e0 = dns.resolver.NXDOMAIN("errmsg")
+            e = e0 + e
+            MSG = "None of DNS query names exist: a.b.s., a.b."
+            if not py3: self.assertTrue((e.message == MSG), e.message)
+            self.assertTrue((e.args == (MSG,)), repr(e.args))
+            self.assertTrue(('kwargs' in dir(e)))
+            self.assertTrue((str(e) == MSG), str(e))
+            self.assertTrue(('qnames' in e.kwargs))
+            self.assertTrue((e.kwargs['qnames'] == [n2, n1]))
+            self.assertTrue(('responses' in e.kwargs))
+            self.assertTrue((e.kwargs['responses'] == {}))
+
+        try:
+            raise dns.resolver.NXDOMAIN(qnames=[n1], responses=['r1.1'])
+        except Exception as e:
+            self.assertTrue((isinstance(e, AttributeError)))
+
+        try:
+            raise dns.resolver.NXDOMAIN(qnames=[n1], responses={n1: 'r1.1'})
+        except Exception as e:
+            MSG = "The DNS query name does not exist: a.b."
+            if not py3: self.assertTrue((e.message == MSG), e.message)
+            self.assertTrue((e.args == (MSG,)), repr(e.args))
+            self.assertTrue(('kwargs' in dir(e)))
+            self.assertTrue((str(e) == MSG), str(e))
+            self.assertTrue(('qnames' in e.kwargs))
+            self.assertTrue((e.kwargs['qnames'] == [n1]))
+            self.assertTrue(('responses' in e.kwargs))
+            self.assertTrue((e.kwargs['responses'] == {n1: 'r1.1'}))
+
+    def test_nxdomain_merge(self):
+        n1 = dns.name.Name(('a', 'b', ''))
+        n2 = dns.name.Name(('a', 'b', ''))
+        n3 = dns.name.Name(('a', 'b', 'c', ''))
+        n4 = dns.name.Name(('a', 'b', 'd', ''))
+        responses1 = {n1: 'r1.1', n2: 'r1.2', n4: 'r1.4'}
+        qnames1 = [n1, n4]   # n2 == n1
+        responses2 = {n2: 'r2.2', n3: 'r2.3'}
+        qnames2 = [n2, n3]
+        e0 = dns.resolver.NXDOMAIN()
+        e1 = dns.resolver.NXDOMAIN(qnames=qnames1, responses=responses1)
+        e2 = dns.resolver.NXDOMAIN(qnames=qnames2, responses=responses2)
+        e = e1 + e0 + e2
+        self.assertRaises(AttributeError, lambda : e0 + e0)
+        self.assertTrue(e.kwargs['qnames'] == [n1, n4, n3], repr(e.kwargs['qnames']))
+        self.assertTrue(e.kwargs['responses'][n1].startswith('r2.'))
+        self.assertTrue(e.kwargs['responses'][n2].startswith('r2.'))
+        self.assertTrue(e.kwargs['responses'][n3].startswith('r2.'))
+        self.assertTrue(e.kwargs['responses'][n4].startswith('r1.'))
+
+    def test_nxdomain_canonical_name(self):
+        cname0 = "91.11.8-22.17.172.in-addr.arpa.none."
+        cname1 = "91.11.8-22.17.172.in-addr.arpa."
+        cname2 = "91-11-17-172.dynamic.example."
+        message0 = dns.message.from_text(dangling_cname_0_message_text)
+        message1 = dns.message.from_text(dangling_cname_1_message_text)
+        message2 = dns.message.from_text(dangling_cname_2_message_text)
+        qname0 = message0.question[0].name
+        qname1 = message1.question[0].name
+        qname2 = message2.question[0].name
+        responses = {qname0: message0, qname1: message1, qname2: message2}
+        eX = dns.resolver.NXDOMAIN()
+        e0 = dns.resolver.NXDOMAIN(qnames=[qname0], responses=responses)
+        e1 = dns.resolver.NXDOMAIN(qnames=[qname0, qname1, qname2], responses=responses)
+        e2 = dns.resolver.NXDOMAIN(qnames=[qname0, qname2, qname1], responses=responses)
+        self.assertRaises(TypeError, lambda : eX.canonical_name)
+        self.assertTrue(e0.canonical_name == qname0)
+        self.assertTrue(e1.canonical_name == dns.name.from_text(cname1))
+        self.assertTrue(e2.canonical_name == dns.name.from_text(cname2))
+
+
 if __name__ == '__main__':
     unittest.main()