]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Re-enable TLS async tests.
authorBob Halley <halley@dnspython.org>
Fri, 12 Jun 2020 15:01:33 +0000 (08:01 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 12 Jun 2020 15:01:33 +0000 (08:01 -0700)
tests/test_async.py

index ef07bb14a8529ed92d62984c17cf371d76de5673..ed75b9c10753dfae5a07bc8327f3cda31a2f863c 100644 (file)
@@ -28,6 +28,16 @@ import dns.rdataclass
 import dns.rdatatype
 import dns.resolver
 
+
+# Some tests require TLS so skip those if it's not there.
+from dns.query import ssl
+try:
+    ssl.create_default_context()
+    _ssl_available = True
+except Exception:
+    _ssl_available = False
+
+
 # Some tests require the internet to be available to run, so let's
 # skip those if it's not there.
 _network_available = True
@@ -36,6 +46,7 @@ try:
 except socket.gaierror:
     _network_available = False
 
+
 @unittest.skipIf(not _network_available, "Internet not reachable")
 class AsyncTests(unittest.TestCase):
 
@@ -157,33 +168,40 @@ class AsyncTests(unittest.TestCase):
         self.assertTrue('8.8.8.8' in seen)
         self.assertTrue('8.8.4.4' in seen)
 
-    # def testQueryTLS(self):
-    #     qname = dns.name.from_text('dns.google.')
-    #     async def run():
-    #         q = dns.message.make_query(qname, dns.rdatatype.A)
-    #         return await dns.asyncquery.stream(q, '8.8.8.8', True)
-    #     response = self.async_run(run)
-    #     rrs = response.get_rrset(response.answer, qname,
-    #                              dns.rdataclass.IN, dns.rdatatype.A)
-    #     self.assertTrue(rrs is not None)
-    #     seen = set([rdata.address for rdata in rrs])
-    #     self.assertTrue('8.8.8.8' in seen)
-    #     self.assertTrue('8.8.4.4' in seen)
-
-    # def testQueryTLSWithSocket(self):
-    #     qname = dns.name.from_text('dns.google.')
-    #     async def run():
-    #         async with await trio.open_ssl_over_tcp_stream('8.8.8.8',
-    #                                                        853) as s:
-    #             q = dns.message.make_query(qname, dns.rdatatype.A)
-    #             return await dns.asyncquery.stream(q, '8.8.8.8', stream=s)
-    #     response = self.async_run(run)
-    #     rrs = response.get_rrset(response.answer, qname,
-    #                              dns.rdataclass.IN, dns.rdatatype.A)
-    #     self.assertTrue(rrs is not None)
-    #     seen = set([rdata.address for rdata in rrs])
-    #     self.assertTrue('8.8.8.8' in seen)
-    #     self.assertTrue('8.8.4.4' in seen)
+    @unittest.skipIf(not _ssl_available, "SSL not available")
+    def testQueryTLS(self):
+        qname = dns.name.from_text('dns.google.')
+        async def run():
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            return await dns.asyncquery.tls(q, '8.8.8.8')
+        response = self.async_run(run)
+        rrs = response.get_rrset(response.answer, qname,
+                                 dns.rdataclass.IN, dns.rdatatype.A)
+        self.assertTrue(rrs is not None)
+        seen = set([rdata.address for rdata in rrs])
+        self.assertTrue('8.8.8.8' in seen)
+        self.assertTrue('8.8.4.4' in seen)
+
+    @unittest.skipIf(not _ssl_available, "SSL not available")
+    def testQueryTLSWithSocket(self):
+        qname = dns.name.from_text('dns.google.')
+        async def run():
+            ssl_context = ssl.create_default_context()
+            ssl_context.check_hostname = False
+            async with await self.backend.make_socket(socket.AF_INET,
+                                                      socket.SOCK_STREAM, 0,
+                                                      None,
+                                                      ('8.8.8.8', 853), None,
+                                                      ssl_context, None) as s:
+                q = dns.message.make_query(qname, dns.rdatatype.A)
+                return await dns.asyncquery.tls(q, '8.8.8.8', sock=s)
+        response = self.async_run(run)
+        rrs = response.get_rrset(response.answer, qname,
+                                 dns.rdataclass.IN, dns.rdatatype.A)
+        self.assertTrue(rrs is not None)
+        seen = set([rdata.address for rdata in rrs])
+        self.assertTrue('8.8.8.8' in seen)
+        self.assertTrue('8.8.4.4' in seen)
 
     def testQueryUDPFallback(self):
         qname = dns.name.from_text('.')