]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Further improve CVE fix coverage to 100% for sync and async.
authorBob Halley <halley@dnspython.org>
Fri, 16 Feb 2024 16:46:24 +0000 (08:46 -0800)
committerBob Halley <halley@dnspython.org>
Fri, 16 Feb 2024 16:46:24 +0000 (08:46 -0800)
tests/test_async.py
tests/test_query.py

index 4ea2301586a00039ad810e3ae6b3cf1c81ada777..ba2078cde21f4f7621123c72eb42913893076378 100644 (file)
@@ -18,7 +18,6 @@
 import asyncio
 import random
 import socket
-import sys
 import time
 import unittest
 
@@ -28,6 +27,7 @@ import dns.asyncresolver
 import dns.message
 import dns.name
 import dns.query
+import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
 import dns.resolver
@@ -664,3 +664,185 @@ try:
 
 except ImportError:
     pass
+
+
+class MockSock:
+    def __init__(self, wire1, from1, wire2, from2):
+        self.family = socket.AF_INET
+        self.first_time = True
+        self.wire1 = wire1
+        self.from1 = from1
+        self.wire2 = wire2
+        self.from2 = from2
+
+    async def sendto(self, data, where, timeout):
+        return len(data)
+
+    async def recvfrom(self, bufsize, expiration):
+        if self.first_time:
+            self.first_time = False
+            return self.wire1, self.from1
+        else:
+            return self.wire2, self.from2
+
+
+class IgnoreErrors(unittest.TestCase):
+    def setUp(self):
+        self.q = dns.message.make_query("example.", "A")
+        self.good_r = dns.message.make_response(self.q)
+        self.good_r.set_rcode(dns.rcode.NXDOMAIN)
+        self.good_r_wire = self.good_r.to_wire()
+        dns.asyncbackend.set_default_backend("asyncio")
+
+    def async_run(self, afunc):
+        return asyncio.run(afunc())
+
+    async def mock_receive(
+        self,
+        wire1,
+        from1,
+        wire2,
+        from2,
+        ignore_unexpected=True,
+        ignore_errors=True,
+    ):
+        s = MockSock(wire1, from1, wire2, from2)
+        (r, when, _) = await dns.asyncquery.receive_udp(
+            s,
+            ("127.0.0.1", 53),
+            time.time() + 2,
+            ignore_unexpected=ignore_unexpected,
+            ignore_errors=ignore_errors,
+            query=self.q,
+        )
+        self.assertEqual(r, self.good_r)
+
+    def test_good_mock(self):
+        async def run():
+            await self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None)
+
+        self.async_run(run)
+
+    def test_bad_address(self):
+        async def run():
+            await self.mock_receive(
+                self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53)
+            )
+
+        self.async_run(run)
+
+    def test_bad_address_not_ignored(self):
+        async def abad():
+            await self.mock_receive(
+                self.good_r_wire,
+                ("127.0.0.2", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                ignore_unexpected=False,
+            )
+
+        def bad():
+            self.async_run(abad)
+
+        self.assertRaises(dns.query.UnexpectedSource, bad)
+
+    def test_not_response_not_ignored_udp_level(self):
+        async def abad():
+            bad_r = dns.message.make_response(self.q)
+            bad_r.id += 1
+            bad_r_wire = bad_r.to_wire()
+            s = MockSock(
+                bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+            )
+            await dns.asyncquery.udp(self.good_r, "127.0.0.1", sock=s)
+
+        def bad():
+            self.async_run(abad)
+
+        self.assertRaises(dns.query.BadResponse, bad)
+
+    def test_bad_id(self):
+        async def run():
+            bad_r = dns.message.make_response(self.q)
+            bad_r.id += 1
+            bad_r_wire = bad_r.to_wire()
+            await self.mock_receive(
+                bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+            )
+
+        self.async_run(run)
+
+    def test_bad_id_not_ignored(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r_wire = bad_r.to_wire()
+
+        async def abad():
+            (r, wire) = await self.mock_receive(
+                bad_r_wire,
+                ("127.0.0.1", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                ignore_errors=False,
+            )
+
+        def bad():
+            self.async_run(abad)
+
+        self.assertRaises(AssertionError, bad)
+
+    def test_bad_wire(self):
+        async def run():
+            bad_r = dns.message.make_response(self.q)
+            bad_r.id += 1
+            bad_r_wire = bad_r.to_wire()
+            await self.mock_receive(
+                bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+            )
+
+        self.async_run(run)
+
+    def test_bad_wire_not_ignored(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r_wire = bad_r.to_wire()
+
+        async def abad():
+            await self.mock_receive(
+                bad_r_wire[:10],
+                ("127.0.0.1", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                ignore_errors=False,
+            )
+
+        def bad():
+            self.async_run(abad)
+
+        self.assertRaises(dns.message.ShortHeader, bad)
+
+    def test_trailing_wire(self):
+        async def run():
+            wire = self.good_r_wire + b"abcd"
+            await self.mock_receive(
+                wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+            )
+
+        self.async_run(run)
+
+    def test_trailing_wire_not_ignored(self):
+        wire = self.good_r_wire + b"abcd"
+
+        async def abad():
+            await self.mock_receive(
+                wire,
+                ("127.0.0.1", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                ignore_errors=False,
+            )
+
+        def bad():
+            self.async_run(abad)
+
+        self.assertRaises(dns.message.TrailingJunk, bad)
index a47daa459a0a93656c0f705ebcf0c8633ad933c2..1039a14eb2e6200a4fc7c2d545171841a018977d 100644 (file)
@@ -683,6 +683,14 @@ def mock_udp_recv(wire1, from1, wire2, from2):
         dns.query._udp_recv = saved
 
 
+class MockSock:
+    def __init__(self):
+        self.family = socket.AF_INET
+
+    def sendto(self, data, where):
+        return len(data)
+
+
 class IgnoreErrors(unittest.TestCase):
     def setUp(self):
         self.q = dns.message.make_query("example.", "A")
@@ -758,6 +766,19 @@ class IgnoreErrors(unittest.TestCase):
 
         self.assertRaises(AssertionError, bad)
 
+    def test_not_response_not_ignored_udp_level(self):
+        def bad():
+            bad_r = dns.message.make_response(self.q)
+            bad_r.id += 1
+            bad_r_wire = bad_r.to_wire()
+            with mock_udp_recv(
+                bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+            ):
+                s = MockSock()
+                dns.query.udp(self.good_r, "127.0.0.1", sock=s)
+
+        self.assertRaises(dns.query.BadResponse, bad)
+
     def test_bad_wire(self):
         bad_r = dns.message.make_response(self.q)
         bad_r.id += 1