]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Fix leaked socket in testQueryTLSWithSocket.
authorBrian Wellington <bwelling@xbill.org>
Mon, 8 Jun 2020 18:01:44 +0000 (11:01 -0700)
committerBrian Wellington <bwelling@xbill.org>
Mon, 8 Jun 2020 18:01:44 +0000 (11:01 -0700)
tests/test_query.py

index e031cfd19325ff6d6f2d287a43ef978df7862757..b9699d2738480413ce0f81b684f20a928046bdd6 100644 (file)
@@ -103,20 +103,20 @@ class QueryTests(unittest.TestCase):
 
     @unittest.skipUnless(have_ssl, "No SSL support")
     def testQueryTLSWithSocket(self):
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.connect(('8.8.8.8', 853))
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as base_s:
+            base_s.connect(('8.8.8.8', 853))
             ctx = ssl.create_default_context()
-            s = ctx.wrap_socket(s, server_hostname='dns.google')
-            s.setblocking(0)
-            qname = dns.name.from_text('dns.google.')
-            q = dns.message.make_query(qname, dns.rdatatype.A)
-            response = dns.query.tls(q, None, sock=s)
-            rrs = response.get_rrset(response.answer, qname,
-                                     dns.rdataclass.IN, dns.rdatatype.A)
-            self.assertTrue(rrs is not None)
-            seen = set([rdata.address for rdata in rrs])
-            self.assertTrue('8.8.8.8' in seen)
-            self.assertTrue('8.8.4.4' in seen)
+            with ctx.wrap_socket(base_s, server_hostname='dns.google') as s:
+                s.setblocking(0)
+                qname = dns.name.from_text('dns.google.')
+                q = dns.message.make_query(qname, dns.rdatatype.A)
+                response = dns.query.tls(q, None, sock=s)
+                rrs = response.get_rrset(response.answer, qname,
+                                         dns.rdataclass.IN, dns.rdatatype.A)
+                self.assertTrue(rrs is not None)
+                seen = set([rdata.address for rdata in rrs])
+                self.assertTrue('8.8.8.8' in seen)
+                self.assertTrue('8.8.4.4' in seen)
 
     def testQueryUDPFallback(self):
         qname = dns.name.from_text('.')