]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
The Tudoor fix should not eat valid Truncated exceptions [#1053] (#1054)
authorBob Halley <halley@dnspython.org>
Sun, 18 Feb 2024 18:27:43 +0000 (10:27 -0800)
committerGitHub <noreply@github.com>
Sun, 18 Feb 2024 18:27:43 +0000 (10:27 -0800)
* The Tudoor fix should not eat valid Truncated exceptions [##1053]

* Make logic more readable

dns/asyncquery.py
dns/query.py
tests/test_async.py
tests/test_query.py

index 94cb24133172eb0b4c1ee8003567c7acf037756d..4d9ab9ae49385e83515143ced8a04b01938fcab1 100644 (file)
@@ -151,6 +151,16 @@ async def receive_udp(
                 ignore_trailing=ignore_trailing,
                 raise_on_truncation=raise_on_truncation,
             )
+        except dns.message.Truncated as e:
+            # See the comment in query.py for details.
+            if (
+                ignore_errors
+                and query is not None
+                and not query.is_response(e.message())
+            ):
+                continue
+            else:
+                raise
         except Exception:
             if ignore_errors:
                 continue
index 06d186c7e4318e76133a2dc305e3137498ac36e1..384bf31e388f0f09f4d2e6696f038c0ee4d1f150 100644 (file)
@@ -618,6 +618,20 @@ def receive_udp(
                 ignore_trailing=ignore_trailing,
                 raise_on_truncation=raise_on_truncation,
             )
+        except dns.message.Truncated as e:
+            # If we got Truncated and not FORMERR, we at least got the header with TC
+            # set, and very likely the question section, so we'll re-raise if the
+            # message seems to be a response as we need to know when truncation happens.
+            # We need to check that it seems to be a response as we don't want a random
+            # injected message with TC set to cause us to bail out.
+            if (
+                ignore_errors
+                and query is not None
+                and not query.is_response(e.message())
+            ):
+                continue
+            else:
+                raise
         except Exception:
             if ignore_errors:
                 continue
index ba2078cde21f4f7621123c72eb42913893076378..9373548d7682a81fca688b0568cd00c7bc530304 100644 (file)
@@ -705,7 +705,11 @@ class IgnoreErrors(unittest.TestCase):
         from2,
         ignore_unexpected=True,
         ignore_errors=True,
+        raise_on_truncation=False,
+        good_r=None,
     ):
+        if good_r is None:
+            good_r = self.good_r
         s = MockSock(wire1, from1, wire2, from2)
         (r, when, _) = await dns.asyncquery.receive_udp(
             s,
@@ -713,9 +717,10 @@ class IgnoreErrors(unittest.TestCase):
             time.time() + 2,
             ignore_unexpected=ignore_unexpected,
             ignore_errors=ignore_errors,
+            raise_on_truncation=raise_on_truncation,
             query=self.q,
         )
-        self.assertEqual(r, self.good_r)
+        self.assertEqual(r, good_r)
 
     def test_good_mock(self):
         async def run():
@@ -802,6 +807,59 @@ class IgnoreErrors(unittest.TestCase):
 
         self.async_run(run)
 
+    def test_good_wire_with_truncation_flag_and_no_truncation_raise(self):
+        async def run():
+            tc_r = dns.message.make_response(self.q)
+            tc_r.flags |= dns.flags.TC
+            tc_r_wire = tc_r.to_wire()
+            await self.mock_receive(
+                tc_r_wire, ("127.0.0.1", 53), None, None, good_r=tc_r
+            )
+
+        self.async_run(run)
+
+    def test_good_wire_with_truncation_flag_and_truncation_raise(self):
+        async def agood():
+            tc_r = dns.message.make_response(self.q)
+            tc_r.flags |= dns.flags.TC
+            tc_r_wire = tc_r.to_wire()
+            await self.mock_receive(
+                tc_r_wire, ("127.0.0.1", 53), None, None, raise_on_truncation=True
+            )
+
+        def good():
+            self.async_run(agood)
+
+        self.assertRaises(dns.message.Truncated, good)
+
+    def test_wrong_id_wire_with_truncation_flag_and_no_truncation_raise(self):
+        async def run():
+            bad_r = dns.message.make_response(self.q)
+            bad_r.id += 1
+            bad_r.flags |= dns.flags.TC
+            bad_r_wire = bad_r.to_wire()
+            await self.mock_receive(
+                bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+            )
+
+        self.async_run(run)
+
+    def test_wrong_id_wire_with_truncation_flag_and_truncation_raise(self):
+        async def run():
+            bad_r = dns.message.make_response(self.q)
+            bad_r.id += 1
+            bad_r.flags |= dns.flags.TC
+            bad_r_wire = bad_r.to_wire()
+            await self.mock_receive(
+                bad_r_wire,
+                ("127.0.0.1", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                raise_on_truncation=True,
+            )
+
+        self.async_run(run)
+
     def test_bad_wire_not_ignored(self):
         bad_r = dns.message.make_response(self.q)
         bad_r.id += 1
index 1039a14eb2e6200a4fc7c2d545171841a018977d..62007e857b594e2c7a0810b17cf190d6573dc3d5 100644 (file)
@@ -29,6 +29,7 @@ except Exception:
     have_ssl = False
 
 import dns.exception
+import dns.flags
 import dns.inet
 import dns.message
 import dns.name
@@ -706,7 +707,11 @@ class IgnoreErrors(unittest.TestCase):
         from2,
         ignore_unexpected=True,
         ignore_errors=True,
+        raise_on_truncation=False,
+        good_r=None,
     ):
+        if good_r is None:
+            good_r = self.good_r
         s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         try:
             with mock_udp_recv(wire1, from1, wire2, from2):
@@ -716,9 +721,10 @@ class IgnoreErrors(unittest.TestCase):
                     time.time() + 2,
                     ignore_unexpected=ignore_unexpected,
                     ignore_errors=ignore_errors,
+                    raise_on_truncation=raise_on_truncation,
                     query=self.q,
                 )
-                self.assertEqual(r, self.good_r)
+                self.assertEqual(r, good_r)
         finally:
             s.close()
 
@@ -787,6 +793,42 @@ class IgnoreErrors(unittest.TestCase):
             bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
         )
 
+    def test_good_wire_with_truncation_flag_and_no_truncation_raise(self):
+        tc_r = dns.message.make_response(self.q)
+        tc_r.flags |= dns.flags.TC
+        tc_r_wire = tc_r.to_wire()
+        self.mock_receive(tc_r_wire, ("127.0.0.1", 53), None, None, good_r=tc_r)
+
+    def test_good_wire_with_truncation_flag_and_truncation_raise(self):
+        def good():
+            tc_r = dns.message.make_response(self.q)
+            tc_r.flags |= dns.flags.TC
+            tc_r_wire = tc_r.to_wire()
+            self.mock_receive(
+                tc_r_wire, ("127.0.0.1", 53), None, None, raise_on_truncation=True
+            )
+
+        self.assertRaises(dns.message.Truncated, good)
+
+    def test_wrong_id_wire_with_truncation_flag_and_no_truncation_raise(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r.flags |= dns.flags.TC
+        bad_r_wire = bad_r.to_wire()
+        self.mock_receive(
+            bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+        )
+
+    def test_wrong_id_wire_with_truncation_flag_and_truncation_raise(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r.flags |= dns.flags.TC
+        bad_r_wire = bad_r.to_wire()
+        self.mock_receive(
+            bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53),
+            raise_on_truncation=True
+        )
+
     def test_bad_wire_not_ignored(self):
         bad_r = dns.message.make_response(self.q)
         bad_r.id += 1