]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Wrap exceptions from rdata from_text() and from_wire(). 553/head
authorBob Halley <halley@dnspython.org>
Mon, 27 Jul 2020 00:48:14 +0000 (17:48 -0700)
committerBob Halley <halley@dnspython.org>
Mon, 27 Jul 2020 00:48:14 +0000 (17:48 -0700)
dns/exception.py
dns/rdata.py
tests/test_rdata.py

index 8f1d48883090668de05827a3e059b8a4de86dae8..9486f4507421ce115d32e9e2dd0c42cb50035bdd 100644 (file)
@@ -126,3 +126,17 @@ class Timeout(DNSException):
     """The DNS operation timed out."""
     supp_kwargs = {'timeout'}
     fmt = "The DNS operation timed out after {timeout} seconds"
+
+
+class ExceptionWrapper:
+    def __init__(self, exception_class):
+        self.exception_class = exception_class
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is not None and not isinstance(exc_val,
+                                                   self.exception_class):
+            raise self.exception_class() from exc_val
+        return False
index 2d08dcc9bb25e6eea7b56c78433f7b72e08358cf..0daa08dab05fa85da2a2e6b9a73d171f570b2e98 100644 (file)
@@ -459,35 +459,35 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
     Returns an instance of the chosen Rdata subclass.
 
     """
-
     if isinstance(tok, str):
         tok = dns.tokenizer.Tokenizer(tok, idna_codec=idna_codec)
     rdclass = dns.rdataclass.RdataClass.make(rdclass)
     rdtype = dns.rdatatype.RdataType.make(rdtype)
     cls = get_rdata_class(rdclass, rdtype)
-    rdata = None
-    if cls != GenericRdata:
-        # peek at first token
-        token = tok.get()
-        tok.unget(token)
-        if token.is_identifier() and \
-           token.value == r'\#':
-            #
-            # Known type using the generic syntax.  Extract the
-            # wire form from the generic syntax, and then run
-            # from_wire on it.
-            #
-            grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin,
-                                            relativize, relativize_to)
-            rdata = from_wire(rdclass, rdtype, grdata.data, 0, len(grdata.data),
-                              origin)
-    if rdata is None:
-        rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize,
-                              relativize_to)
-    token = tok.get_eol_as_token()
-    if token.comment is not None:
-        object.__setattr__(rdata, 'rdcomment', token.comment)
-    return rdata
+    with dns.exception.ExceptionWrapper(dns.exception.SyntaxError):
+        rdata = None
+        if cls != GenericRdata:
+            # peek at first token
+            token = tok.get()
+            tok.unget(token)
+            if token.is_identifier() and \
+               token.value == r'\#':
+                #
+                # Known type using the generic syntax.  Extract the
+                # wire form from the generic syntax, and then run
+                # from_wire on it.
+                #
+                grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin,
+                                                relativize, relativize_to)
+                rdata = from_wire(rdclass, rdtype, grdata.data, 0,
+                                  len(grdata.data), origin)
+        if rdata is None:
+            rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize,
+                                  relativize_to)
+        token = tok.get_eol_as_token()
+        if token.comment is not None:
+            object.__setattr__(rdata, 'rdcomment', token.comment)
+        return rdata
 
 
 def from_wire_parser(rdclass, rdtype, parser, origin=None):
@@ -517,7 +517,8 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None):
     rdclass = dns.rdataclass.RdataClass.make(rdclass)
     rdtype = dns.rdatatype.RdataType.make(rdtype)
     cls = get_rdata_class(rdclass, rdtype)
-    return cls.from_wire_parser(rdclass, rdtype, parser, origin)
+    with dns.exception.ExceptionWrapper(dns.exception.FormError):
+        return cls.from_wire_parser(rdclass, rdtype, parser, origin)
 
 
 def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None):
index 8d9937e19bf273ad259f8022d5af4af08dede26b..090ca9b868bb8cbe52a3cc6e2eabcdc5f7453724 100644 (file)
@@ -385,10 +385,12 @@ class RdataTestCase(unittest.TestCase):
         self.equal_wks('10.0.0.1 udp ( domain )', '10.0.0.1 17 ( 53 )')
 
     def test_misc_bad_WKS_text(self):
-        def bad():
+        try:
             dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.WKS,
                                 '10.0.0.1 132 ( domain )')
-        self.assertRaises(NotImplementedError, bad)
+            self.assertTrue(False)  # should not happen
+        except dns.exception.SyntaxError as e:
+            self.assertIsInstance(e.__cause__, NotImplementedError)
 
     def test_GPOS_float_converters(self):
         rd = dns.rdata.from_text('in', 'gpos', '49 0 0')
@@ -426,7 +428,7 @@ class RdataTestCase(unittest.TestCase):
                     '"0" "-180.1" "0"',
                     ]
         for gpos in bad_gpos:
-            with self.assertRaises(dns.exception.FormError):
+            with self.assertRaises(dns.exception.SyntaxError):
                 dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.GPOS, gpos)
 
     def test_bad_GPOS_wire(self):
@@ -556,11 +558,11 @@ class RdataTestCase(unittest.TestCase):
     def test_CERT_algorithm(self):
         rd = dns.rdata.from_text('in', 'cert', 'SPKI 1 0 Ym9ndXM=')
         self.assertEqual(rd.algorithm, 0)
-        with self.assertRaises(ValueError):
+        with self.assertRaises(dns.exception.SyntaxError):
             dns.rdata.from_text('in', 'cert', 'SPKI 1 -1 Ym9ndXM=')
-        with self.assertRaises(ValueError):
+        with self.assertRaises(dns.exception.SyntaxError):
             dns.rdata.from_text('in', 'cert', 'SPKI 1 256 Ym9ndXM=')
-        with self.assertRaises(ValueError):
+        with self.assertRaises(dns.exception.SyntaxError):
             dns.rdata.from_text('in', 'cert', 'SPKI 1 BOGUS Ym9ndXM=')
 
     def test_bad_URI_text(self):
@@ -603,16 +605,24 @@ class RdataTestCase(unittest.TestCase):
                                 ' Ym9ndXM=')
 
     def test_bad_sigtime(self):
-        with self.assertRaises(dns.rdtypes.ANY.RRSIG.BadSigTime):
+        try:
             dns.rdata.from_text('in', 'rrsig',
                                 'NSEC 1 3 3600 ' +
                                 '202001010000000 20030101000000 ' +
                                 '2143 foo Ym9ndXM=')
-        with self.assertRaises(dns.rdtypes.ANY.RRSIG.BadSigTime):
+            self.assertTrue(False)  # should not happen
+        except dns.exception.SyntaxError as e:
+            self.assertIsInstance(e.__cause__,
+                                  dns.rdtypes.ANY.RRSIG.BadSigTime)
+        try:
             dns.rdata.from_text('in', 'rrsig',
                                 'NSEC 1 3 3600 ' +
                                 '20200101000000 2003010100000 ' +
                                 '2143 foo Ym9ndXM=')
+            self.assertTrue(False)  # should not happen
+        except dns.exception.SyntaxError as e:
+            self.assertIsInstance(e.__cause__,
+                                  dns.rdtypes.ANY.RRSIG.BadSigTime)
 
     def test_empty_TXT(self):
         # hit too long