]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Test passing ssl_context to tls query functions.
authorBob Halley <halley@dnspython.org>
Sat, 8 Jul 2023 19:58:12 +0000 (12:58 -0700)
committerBob Halley <halley@dnspython.org>
Sat, 8 Jul 2023 19:58:12 +0000 (12:58 -0700)
tests/test_async.py
tests/test_query.py

index d46f79e72aedeff949cbeeb3e1bae9235e6a75a4..d0f977a2147e7acb91557c8c11fceae7d116b2fd 100644 (file)
@@ -408,6 +408,28 @@ class AsyncTests(unittest.TestCase):
             self.assertTrue("8.8.8.8" in seen)
             self.assertTrue("8.8.4.4" in seen)
 
+    @unittest.skipIf(not _ssl_available, "SSL not available")
+    def testQueryTLSWithContext(self):
+        for address in query_addresses:
+            qname = dns.name.from_text("dns.google.")
+
+            async def run():
+                ssl_context = ssl.create_default_context()
+                ssl_context.check_hostname = True
+                q = dns.message.make_query(qname, dns.rdatatype.A)
+                return await dns.asyncquery.tls(
+                    q, address, timeout=2, ssl_context=ssl_context
+                )
+
+            response = self.async_run(run)
+            rrs = response.get_rrset(
+                response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+            )
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue("8.8.8.8" in seen)
+            self.assertTrue("8.8.4.4" in seen)
+
     @unittest.skipIf(not _ssl_available, "SSL not available")
     def testQueryTLSWithSocket(self):
         for address in query_addresses:
@@ -640,8 +662,8 @@ class AsyncioOnlyTests(unittest.TestCase):
 
 
 try:
-    import trio
     import sniffio
+    import trio
 
     class TrioAsyncDetectionTests(AsyncDetectionTests):
         sniff_result = "trio"
index 7f41f3409258ba42e2141187d9cd4d29f2ce2afd..1116b2d128842b9bda3d9a4613f0415005c274f2 100644 (file)
@@ -31,9 +31,9 @@ import dns.exception
 import dns.inet
 import dns.message
 import dns.name
+import dns.query
 import dns.rdataclass
 import dns.rdatatype
-import dns.query
 import dns.tsigkeyring
 import dns.zone
 import tests.util
@@ -140,6 +140,22 @@ class QueryTests(unittest.TestCase):
             self.assertTrue("8.8.8.8" in seen)
             self.assertTrue("8.8.4.4" in seen)
 
+    @unittest.skipUnless(have_ssl, "No SSL support")
+    def testQueryTLSWithContext(self):
+        for address in query_addresses:
+            qname = dns.name.from_text("dns.google.")
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            ssl_context = ssl.create_default_context()
+            ssl_context.check_hostname = False
+            response = dns.query.tls(q, address, timeout=2, ssl_context=ssl_context)
+            rrs = response.get_rrset(
+                response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+            )
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue("8.8.8.8" in seen)
+            self.assertTrue("8.8.4.4" in seen)
+
     @unittest.skipUnless(have_ssl, "No SSL support")
     def testQueryTLSWithSocket(self):
         for address in query_addresses: