# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+import contextlib
import socket
import sys
import time
import dns.message
import dns.name
import dns.query
+import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.tsigkeyring
dns.query._matches_destination(
socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), False
)
+
+
+@contextlib.contextmanager
+def mock_udp_recv(wire1, from1, wire2, from2):
+ saved = dns.query._udp_recv
+ first_time = True
+
+ def mock(sock, max_size, expiration):
+ nonlocal first_time
+ if first_time:
+ first_time = False
+ return wire1, from1
+ else:
+ return wire2, from2
+
+ try:
+ dns.query._udp_recv = mock
+ yield None
+ finally:
+ dns.query._udp_recv = saved
+
+
+class IgnoreErrors(unittest.TestCase):
+ def setUp(self):
+ self.q = dns.message.make_query("example.", "A")
+ self.good_r = dns.message.make_response(self.q)
+ self.good_r.set_rcode(dns.rcode.NXDOMAIN)
+ self.good_r_wire = self.good_r.to_wire()
+
+ def mock_receive(
+ self,
+ wire1,
+ from1,
+ wire2,
+ from2,
+ ignore_unexpected=True,
+ ignore_errors=True,
+ ):
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ try:
+ with mock_udp_recv(wire1, from1, wire2, from2):
+ (r, when) = dns.query.receive_udp(
+ s,
+ ("127.0.0.1", 53),
+ time.time() + 2,
+ ignore_unexpected=ignore_unexpected,
+ ignore_errors=ignore_errors,
+ query=self.q,
+ )
+ self.assertEqual(r, self.good_r)
+ finally:
+ s.close()
+
+ def test_good_mock(self):
+ self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None)
+
+ def test_bad_address(self):
+ self.mock_receive(
+ self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53)
+ )
+
+ def test_bad_address_not_ignored(self):
+ def bad():
+ self.mock_receive(
+ self.good_r_wire,
+ ("127.0.0.2", 53),
+ self.good_r_wire,
+ ("127.0.0.1", 53),
+ ignore_unexpected=False,
+ )
+
+ self.assertRaises(dns.query.UnexpectedSource, bad)
+
+ def test_bad_id(self):
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r_wire = bad_r.to_wire()
+ self.mock_receive(
+ bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+ )
+
+ def test_bad_id_not_ignored(self):
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r_wire = bad_r.to_wire()
+
+ def bad():
+ (r, wire) = self.mock_receive(
+ bad_r_wire,
+ ("127.0.0.1", 53),
+ self.good_r_wire,
+ ("127.0.0.1", 53),
+ ignore_errors=False,
+ )
+
+ self.assertRaises(AssertionError, bad)
+
+ def test_bad_wire(self):
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r_wire = bad_r.to_wire()
+ self.mock_receive(
+ bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+ )
+
+ def test_bad_wire_not_ignored(self):
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r_wire = bad_r.to_wire()
+
+ def bad():
+ self.mock_receive(
+ bad_r_wire[:10],
+ ("127.0.0.1", 53),
+ self.good_r_wire,
+ ("127.0.0.1", 53),
+ ignore_errors=False,
+ )
+
+ self.assertRaises(dns.message.ShortHeader, bad)
+
+ def test_trailing_wire(self):
+ wire = self.good_r_wire + b"abcd"
+ self.mock_receive(wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53))
+
+ def test_trailing_wire_not_ignored(self):
+ wire = self.good_r_wire + b"abcd"
+
+ def bad():
+ self.mock_receive(
+ wire,
+ ("127.0.0.1", 53),
+ self.good_r_wire,
+ ("127.0.0.1", 53),
+ ignore_errors=False,
+ )
+
+ self.assertRaises(dns.message.TrailingJunk, bad)