]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
First pass at adding network timeouts to tests. This is for when
authorBob Halley <halley@dnspython.org>
Fri, 9 Oct 2020 16:35:14 +0000 (09:35 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 9 Oct 2020 16:35:14 +0000 (09:35 -0700)
it looks like we have a network but it's not connected to the
Internet.

tests/test_async.py
tests/test_doh.py
tests/test_query.py

index 690a1ebdeb9435f7eee391a6cc990c6afd62219e..e9a26bb387d859c0cbcfa9b958af9d34d335fad2 100644 (file)
@@ -248,7 +248,7 @@ class AsyncTests(unittest.TestCase):
             qname = dns.name.from_text('dns.google.')
             async def run():
                 q = dns.message.make_query(qname, dns.rdatatype.A)
-                return await dns.asyncquery.udp(q, address)
+                return await dns.asyncquery.udp(q, address, timeout=2)
             response = self.async_run(run)
             rrs = response.get_rrset(response.answer, qname,
                                      dns.rdataclass.IN, dns.rdatatype.A)
@@ -265,7 +265,8 @@ class AsyncTests(unittest.TestCase):
                         dns.inet.af_for_address(address),
                         socket.SOCK_DGRAM) as s:
                     q = dns.message.make_query(qname, dns.rdatatype.A)
-                    return await dns.asyncquery.udp(q, address, sock=s)
+                    return await dns.asyncquery.udp(q, address, sock=s,
+                                                    timeout=2)
             response = self.async_run(run)
             rrs = response.get_rrset(response.answer, qname,
                                      dns.rdataclass.IN, dns.rdatatype.A)
@@ -279,7 +280,7 @@ class AsyncTests(unittest.TestCase):
             qname = dns.name.from_text('dns.google.')
             async def run():
                 q = dns.message.make_query(qname, dns.rdatatype.A)
-                return await dns.asyncquery.tcp(q, address)
+                return await dns.asyncquery.tcp(q, address, timeout=2)
             response = self.async_run(run)
             rrs = response.get_rrset(response.answer, qname,
                                      dns.rdataclass.IN, dns.rdatatype.A)
@@ -296,11 +297,12 @@ class AsyncTests(unittest.TestCase):
                         dns.inet.af_for_address(address),
                         socket.SOCK_STREAM, 0,
                         None,
-                        (address, 53)) as s:
+                        (address, 53), 2) as s:
                     # for basic coverage
                     await s.getsockname()
                     q = dns.message.make_query(qname, dns.rdatatype.A)
-                    return await dns.asyncquery.tcp(q, address, sock=s)
+                    return await dns.asyncquery.tcp(q, address, sock=s,
+                                                    timeout=2)
             response = self.async_run(run)
             rrs = response.get_rrset(response.answer, qname,
                                      dns.rdataclass.IN, dns.rdatatype.A)
@@ -315,7 +317,7 @@ class AsyncTests(unittest.TestCase):
             qname = dns.name.from_text('dns.google.')
             async def run():
                 q = dns.message.make_query(qname, dns.rdatatype.A)
-                return await dns.asyncquery.tls(q, address)
+                return await dns.asyncquery.tls(q, address, timeout=2)
             response = self.async_run(run)
             rrs = response.get_rrset(response.answer, qname,
                                      dns.rdataclass.IN, dns.rdatatype.A)
@@ -335,12 +337,13 @@ class AsyncTests(unittest.TestCase):
                         dns.inet.af_for_address(address),
                         socket.SOCK_STREAM, 0,
                         None,
-                        (address, 853), None,
+                        (address, 853), 2,
                         ssl_context, None) as s:
                     # for basic coverage
                     await s.getsockname()
                     q = dns.message.make_query(qname, dns.rdatatype.A)
-                    return await dns.asyncquery.tls(q, '8.8.8.8', sock=s)
+                    return await dns.asyncquery.tls(q, '8.8.8.8', sock=s,
+                                                    timeout=2)
             response = self.async_run(run)
             rrs = response.get_rrset(response.answer, qname,
                                      dns.rdataclass.IN, dns.rdatatype.A)
@@ -354,7 +357,8 @@ class AsyncTests(unittest.TestCase):
             qname = dns.name.from_text('.')
             async def run():
                 q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
-                return await dns.asyncquery.udp_with_fallback(q, address)
+                return await dns.asyncquery.udp_with_fallback(q, address,
+                                                              timeout=2)
             (_, tcp) = self.async_run(run)
             self.assertTrue(tcp)
 
@@ -363,7 +367,8 @@ class AsyncTests(unittest.TestCase):
             qname = dns.name.from_text('dns.google.')
             async def run():
                 q = dns.message.make_query(qname, dns.rdatatype.A)
-                return await dns.asyncquery.udp_with_fallback(q, address)
+                return await dns.asyncquery.udp_with_fallback(q, address,
+                                                              timeout=2)
             (_, tcp) = self.async_run(run)
             self.assertFalse(tcp)
 
index c5c05696aed64d29e2bfdbcf22f7528f753348f4..793a50060c7684840c904460ceac7edc44315fda 100644 (file)
@@ -32,6 +32,7 @@ resolver_v4_addresses = []
 resolver_v6_addresses = []
 try:
     with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+        s.settimeout(4)
         s.connect(('8.8.8.8', 53))
     resolver_v4_addresses = [
         '1.1.1.1',
@@ -77,13 +78,15 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
     def test_get_request(self):
         nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
         q = dns.message.make_query('example.com.', dns.rdatatype.A)
-        r = dns.query.https(q, nameserver_url, session=self.session, post=False)
+        r = dns.query.https(q, nameserver_url, session=self.session, post=False,
+                            timeout=4)
         self.assertTrue(q.is_response(r))
 
     def test_post_request(self):
         nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
         q = dns.message.make_query('example.com.', dns.rdatatype.A)
-        r = dns.query.https(q, nameserver_url, session=self.session, post=True)
+        r = dns.query.https(q, nameserver_url, session=self.session, post=True,
+                            timeout=4)
         self.assertTrue(q.is_response(r))
 
     def test_build_url_from_ip(self):
@@ -95,14 +98,14 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
             # https://8.8.8.8/dns-query
             # So we're just going to do GET requests here
             r = dns.query.https(q, nameserver_ip, session=self.session,
-                                post=False)
+                                post=False, timeout=4)
 
             self.assertTrue(q.is_response(r))
         if resolver_v6_addresses:
             nameserver_ip = random.choice(resolver_v6_addresses)
             q = dns.message.make_query('example.com.', dns.rdatatype.A)
             r = dns.query.https(q, nameserver_ip, session=self.session,
-                                post=False)
+                                post=False, timeout=4)
             self.assertTrue(q.is_response(r))
 
     def test_bootstrap_address(self):
@@ -115,16 +118,17 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
             # make sure CleanBrowsing's IP address will fail TLS certificate
             # check
             with self.assertRaises(SSLError):
-                dns.query.https(q, invalid_tls_url, session=self.session)
+                dns.query.https(q, invalid_tls_url, session=self.session,
+                                timeout=4)
             # use host header
             r = dns.query.https(q, valid_tls_url, session=self.session,
-                                bootstrap_address=ip)
+                                bootstrap_address=ip, timeout=4)
             self.assertTrue(q.is_response(r))
 
     def test_new_session(self):
         nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
         q = dns.message.make_query('example.com.', dns.rdatatype.A)
-        r = dns.query.https(q, nameserver_url)
+        r = dns.query.https(q, nameserver_url, timeout=4)
         self.assertTrue(q.is_response(r))
 
     def test_resolver(self):
index 7a1ec718a92cb1ceb6e8567329f50cbc0b468d8c..8f2b65f94b0a27f856d9b84fc8c2a1a5f83b3aee 100644 (file)
@@ -68,7 +68,7 @@ for (af, address) in ((socket.AF_INET, '8.8.8.8'),
     except Exception:
         pass
 
-keyring = dns.tsigkeyring.from_text({'name' : 'tDz6cfXXGtNivRpQ98hr6A=='})
+keyring = dns.tsigkeyring.from_text({'name': 'tDz6cfXXGtNivRpQ98hr6A=='})
 
 @unittest.skipIf(not _network_available, "Internet not reachable")
 class QueryTests(unittest.TestCase):
@@ -77,7 +77,7 @@ class QueryTests(unittest.TestCase):
         for address in query_addresses:
             qname = dns.name.from_text('dns.google.')
             q = dns.message.make_query(qname, dns.rdatatype.A)
-            response = dns.query.udp(q, address)
+            response = dns.query.udp(q, address, timeout=2)
             rrs = response.get_rrset(response.answer, qname,
                                      dns.rdataclass.IN, dns.rdatatype.A)
             self.assertTrue(rrs is not None)
@@ -92,7 +92,7 @@ class QueryTests(unittest.TestCase):
                 s.setblocking(0)
                 qname = dns.name.from_text('dns.google.')
                 q = dns.message.make_query(qname, dns.rdatatype.A)
-                response = dns.query.udp(q, address, sock=s)
+                response = dns.query.udp(q, address, sock=s, timeout=2)
                 rrs = response.get_rrset(response.answer, qname,
                                          dns.rdataclass.IN, dns.rdatatype.A)
                 self.assertTrue(rrs is not None)
@@ -104,7 +104,7 @@ class QueryTests(unittest.TestCase):
         for address in query_addresses:
             qname = dns.name.from_text('dns.google.')
             q = dns.message.make_query(qname, dns.rdatatype.A)
-            response = dns.query.tcp(q, address)
+            response = dns.query.tcp(q, address, timeout=2)
             rrs = response.get_rrset(response.answer, qname,
                                      dns.rdataclass.IN, dns.rdatatype.A)
             self.assertTrue(rrs is not None)
@@ -117,11 +117,12 @@ class QueryTests(unittest.TestCase):
             with socket.socket(dns.inet.af_for_address(address),
                                socket.SOCK_STREAM) as s:
                 ll = dns.inet.low_level_address_tuple((address, 53))
+                s.settimeout(2)
                 s.connect(ll)
                 s.setblocking(0)
                 qname = dns.name.from_text('dns.google.')
                 q = dns.message.make_query(qname, dns.rdatatype.A)
-                response = dns.query.tcp(q, None, sock=s)
+                response = dns.query.tcp(q, None, sock=s, timeout=2)
                 rrs = response.get_rrset(response.answer, qname,
                                          dns.rdataclass.IN, dns.rdatatype.A)
                 self.assertTrue(rrs is not None)
@@ -133,7 +134,7 @@ class QueryTests(unittest.TestCase):
         for address in query_addresses:
             qname = dns.name.from_text('dns.google.')
             q = dns.message.make_query(qname, dns.rdatatype.A)
-            response = dns.query.tls(q, address)
+            response = dns.query.tls(q, address, timeout=2)
             rrs = response.get_rrset(response.answer, qname,
                                      dns.rdataclass.IN, dns.rdatatype.A)
             self.assertTrue(rrs is not None)
@@ -147,13 +148,14 @@ class QueryTests(unittest.TestCase):
             with socket.socket(dns.inet.af_for_address(address),
                                socket.SOCK_STREAM) as base_s:
                 ll = dns.inet.low_level_address_tuple((address, 853))
+                base_s.settimeout(2)
                 base_s.connect(ll)
                 ctx = ssl.create_default_context()
                 with ctx.wrap_socket(base_s, server_hostname='dns.google') as s:
                     s.setblocking(0)
                     qname = dns.name.from_text('dns.google.')
                     q = dns.message.make_query(qname, dns.rdatatype.A)
-                    response = dns.query.tls(q, None, sock=s)
+                    response = dns.query.tls(q, None, sock=s, timeout=2)
                     rrs = response.get_rrset(response.answer, qname,
                                              dns.rdataclass.IN, dns.rdatatype.A)
                     self.assertTrue(rrs is not None)
@@ -165,7 +167,7 @@ class QueryTests(unittest.TestCase):
         for address in query_addresses:
             qname = dns.name.from_text('.')
             q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
-            (_, tcp) = dns.query.udp_with_fallback(q, address)
+            (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2)
             self.assertTrue(tcp)
 
     def testQueryUDPFallbackWithSocket(self):
@@ -175,20 +177,22 @@ class QueryTests(unittest.TestCase):
                 udp_s.setblocking(0)
                 with socket.socket(af, socket.SOCK_STREAM) as tcp_s:
                     ll = dns.inet.low_level_address_tuple((address, 53))
+                    tcp_s.settimeout(2)
                     tcp_s.connect(ll)
                     tcp_s.setblocking(0)
                     qname = dns.name.from_text('.')
                     q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
                     (_, tcp) = dns.query.udp_with_fallback(q, address,
-                                                          udp_sock=udp_s,
-                                                          tcp_sock=tcp_s)
+                                                           udp_sock=udp_s,
+                                                           tcp_sock=tcp_s,
+                                                           timeout=2)
                     self.assertTrue(tcp)
 
     def testQueryUDPFallbackNoFallback(self):
         for address in query_addresses:
             qname = dns.name.from_text('dns.google.')
             q = dns.message.make_query(qname, dns.rdatatype.A)
-            (_, tcp) = dns.query.udp_with_fallback(q, address)
+            (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2)
             self.assertFalse(tcp)
 
     def testUDPReceiveQuery(self):