]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
test IgnoreErrors
authorBob Halley <halley@dnspython.org>
Fri, 16 Feb 2024 15:14:49 +0000 (07:14 -0800)
committerBob Halley <halley@dnspython.org>
Fri, 16 Feb 2024 15:15:14 +0000 (07:15 -0800)
(cherry picked from commit ac6763f1018458835201b38cae848e4d261f3e5c)

tests/test_query.py

index 1116b2d128842b9bda3d9a4613f0415005c274f2..a47daa459a0a93656c0f705ebcf0c8633ad933c2 100644 (file)
@@ -15,6 +15,7 @@
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
+import contextlib
 import socket
 import sys
 import time
@@ -32,6 +33,7 @@ import dns.inet
 import dns.message
 import dns.name
 import dns.query
+import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
 import dns.tsigkeyring
@@ -659,3 +661,141 @@ class MiscTests(unittest.TestCase):
             dns.query._matches_destination(
                 socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), False
             )
+
+
+@contextlib.contextmanager
+def mock_udp_recv(wire1, from1, wire2, from2):
+    saved = dns.query._udp_recv
+    first_time = True
+
+    def mock(sock, max_size, expiration):
+        nonlocal first_time
+        if first_time:
+            first_time = False
+            return wire1, from1
+        else:
+            return wire2, from2
+
+    try:
+        dns.query._udp_recv = mock
+        yield None
+    finally:
+        dns.query._udp_recv = saved
+
+
+class IgnoreErrors(unittest.TestCase):
+    def setUp(self):
+        self.q = dns.message.make_query("example.", "A")
+        self.good_r = dns.message.make_response(self.q)
+        self.good_r.set_rcode(dns.rcode.NXDOMAIN)
+        self.good_r_wire = self.good_r.to_wire()
+
+    def mock_receive(
+        self,
+        wire1,
+        from1,
+        wire2,
+        from2,
+        ignore_unexpected=True,
+        ignore_errors=True,
+    ):
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        try:
+            with mock_udp_recv(wire1, from1, wire2, from2):
+                (r, when) = dns.query.receive_udp(
+                    s,
+                    ("127.0.0.1", 53),
+                    time.time() + 2,
+                    ignore_unexpected=ignore_unexpected,
+                    ignore_errors=ignore_errors,
+                    query=self.q,
+                )
+                self.assertEqual(r, self.good_r)
+        finally:
+            s.close()
+
+    def test_good_mock(self):
+        self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None)
+
+    def test_bad_address(self):
+        self.mock_receive(
+            self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53)
+        )
+
+    def test_bad_address_not_ignored(self):
+        def bad():
+            self.mock_receive(
+                self.good_r_wire,
+                ("127.0.0.2", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                ignore_unexpected=False,
+            )
+
+        self.assertRaises(dns.query.UnexpectedSource, bad)
+
+    def test_bad_id(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r_wire = bad_r.to_wire()
+        self.mock_receive(
+            bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+        )
+
+    def test_bad_id_not_ignored(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r_wire = bad_r.to_wire()
+
+        def bad():
+            (r, wire) = self.mock_receive(
+                bad_r_wire,
+                ("127.0.0.1", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                ignore_errors=False,
+            )
+
+        self.assertRaises(AssertionError, bad)
+
+    def test_bad_wire(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r_wire = bad_r.to_wire()
+        self.mock_receive(
+            bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+        )
+
+    def test_bad_wire_not_ignored(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r_wire = bad_r.to_wire()
+
+        def bad():
+            self.mock_receive(
+                bad_r_wire[:10],
+                ("127.0.0.1", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                ignore_errors=False,
+            )
+
+        self.assertRaises(dns.message.ShortHeader, bad)
+
+    def test_trailing_wire(self):
+        wire = self.good_r_wire + b"abcd"
+        self.mock_receive(wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53))
+
+    def test_trailing_wire_not_ignored(self):
+        wire = self.good_r_wire + b"abcd"
+
+        def bad():
+            self.mock_receive(
+                wire,
+                ("127.0.0.1", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                ignore_errors=False,
+            )
+
+        self.assertRaises(dns.message.TrailingJunk, bad)