From 6d17824fc423dcb709fc8c38120ccf4d60a1668b Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sat, 8 Jul 2023 12:58:12 -0700 Subject: [PATCH] Test passing ssl_context to tls query functions. --- tests/test_async.py | 24 +++++++++++++++++++++++- tests/test_query.py | 18 +++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/test_async.py b/tests/test_async.py index d46f79e7..d0f977a2 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -408,6 +408,28 @@ class AsyncTests(unittest.TestCase): self.assertTrue("8.8.8.8" in seen) self.assertTrue("8.8.4.4" in seen) + @unittest.skipIf(not _ssl_available, "SSL not available") + def testQueryTLSWithContext(self): + for address in query_addresses: + qname = dns.name.from_text("dns.google.") + + async def run(): + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = True + q = dns.message.make_query(qname, dns.rdatatype.A) + return await dns.asyncquery.tls( + q, address, timeout=2, ssl_context=ssl_context + ) + + response = self.async_run(run) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) + @unittest.skipIf(not _ssl_available, "SSL not available") def testQueryTLSWithSocket(self): for address in query_addresses: @@ -640,8 +662,8 @@ class AsyncioOnlyTests(unittest.TestCase): try: - import trio import sniffio + import trio class TrioAsyncDetectionTests(AsyncDetectionTests): sniff_result = "trio" diff --git a/tests/test_query.py b/tests/test_query.py index 7f41f340..1116b2d1 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -31,9 +31,9 @@ import dns.exception import dns.inet import dns.message import dns.name +import dns.query import dns.rdataclass import dns.rdatatype -import dns.query import dns.tsigkeyring import dns.zone import tests.util @@ -140,6 +140,22 @@ class QueryTests(unittest.TestCase): self.assertTrue("8.8.8.8" in seen) self.assertTrue("8.8.4.4" in seen) + @unittest.skipUnless(have_ssl, "No SSL support") + def testQueryTLSWithContext(self): + for address in query_addresses: + qname = dns.name.from_text("dns.google.") + q = dns.message.make_query(qname, dns.rdatatype.A) + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + response = dns.query.tls(q, address, timeout=2, ssl_context=ssl_context) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) + @unittest.skipUnless(have_ssl, "No SSL support") def testQueryTLSWithSocket(self): for address in query_addresses: -- 2.47.3