]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add tests.util.is_internet_reachable()
authorBenjamin Drung <bdrung@ubuntu.com>
Mon, 16 May 2022 10:39:40 +0000 (12:39 +0200)
committerBenjamin Drung <bdrung@ubuntu.com>
Mon, 16 May 2022 10:39:40 +0000 (12:39 +0200)
Introduce `tests.util.is_internet_reachable` to avoid duplicate code.

Signed-off-by: Benjamin Drung <bdrung@ubuntu.com>
tests/test_async.py
tests/test_doh.py
tests/test_query.py
tests/test_resolver.py
tests/test_resolver_override.py
tests/util.py

index 3c9a7e6d123fdd97b20fcfc7cef83af52179d89c..21758d9ec2928b5053112b2534eb1295d1f2208a 100644 (file)
@@ -31,7 +31,7 @@ import dns.query
 import dns.rdataclass
 import dns.rdatatype
 import dns.resolver
-
+import tests.util
 
 # Some tests require TLS so skip those if it's not there.
 ssl = dns.query.ssl
@@ -42,15 +42,6 @@ except Exception:
     _ssl_available = False
 
 
-# Some tests require the internet to be available to run, so let's
-# skip those if it's not there.
-_network_available = True
-try:
-    socket.gethostbyname("dnspython.org")
-except socket.gaierror:
-    _network_available = False
-
-
 # Look for systemd-resolved, as it does dangling CNAME responses incorrectly.
 #
 # Currently we simply check if the nameserver is 127.0.0.53.
@@ -178,7 +169,7 @@ class MiscQuery(unittest.TestCase):
         self.assertEqual(t, ("::", 53))
 
 
-@unittest.skipIf(not _network_available, "Internet not reachable")
+@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
 class AsyncTests(unittest.TestCase):
     connect_udp = sys.platform == "win32"
 
index bae28dcbc9abff0d0d5fad518d2d65a1350615ad..c4d6476cd1ac8cdeb0df09f93f73aa2bdb10c905 100644 (file)
@@ -38,6 +38,8 @@ if dns.query._have_requests:
 if dns.query._have_httpx:
     import httpx
 
+import tests.util
+
 # Probe for IPv4 and IPv6
 resolver_v4_addresses = []
 resolver_v6_addresses = []
@@ -75,17 +77,9 @@ KNOWN_PAD_AWARE_DOH_RESOLVER_URLS = [
     "https://dns.google/dns-query",
 ]
 
-# Some tests require the internet to be available to run, so let's
-# skip those if it's not there.
-_network_available = True
-try:
-    socket.gethostbyname("dnspython.org")
-except socket.gaierror:
-    _network_available = False
-
 
 @unittest.skipUnless(
-    dns.query._have_requests and _network_available,
+    dns.query._have_requests and tests.util.is_internet_reachable(),
     "Python requests cannot be imported; no DNS over HTTPS (DOH)",
 )
 class DNSOverHTTPSTestCaseRequests(unittest.TestCase):
@@ -165,7 +159,7 @@ class DNSOverHTTPSTestCaseRequests(unittest.TestCase):
 
 
 @unittest.skipUnless(
-    dns.query._have_httpx and _network_available and _have_ssl,
+    dns.query._have_httpx and tests.util.is_internet_reachable() and _have_ssl,
     "Python httpx cannot be imported; no DNS over HTTPS (DOH)",
 )
 class DNSOverHTTPSTestCaseHttpx(unittest.TestCase):
index 88d6375345e1ba88685af8bbfc405a9b836b9d95..ed2f112dfff082ba9b95344e5a40895f56ddadf4 100644 (file)
@@ -36,14 +36,7 @@ import dns.rdatatype
 import dns.query
 import dns.tsigkeyring
 import dns.zone
-
-# Some tests require the internet to be available to run, so let's
-# skip those if it's not there.
-_network_available = True
-try:
-    socket.gethostbyname("dnspython.org")
-except socket.gaierror:
-    _network_available = False
+import tests.util
 
 # Some tests use a "nano nameserver" for testing.  It requires trio
 # and threading, so try to import it and if it doesn't work, skip
@@ -77,7 +70,7 @@ for (af, address) in (
 keyring = dns.tsigkeyring.from_text({"name": "tDz6cfXXGtNivRpQ98hr6A=="})
 
 
-@unittest.skipIf(not _network_available, "Internet not reachable")
+@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
 class QueryTests(unittest.TestCase):
     def testQueryUDP(self):
         for address in query_addresses:
index dc2dde4a1b716c345e4fc921c4c3e1081e92b123..9b8ea3d0285007f110466bce73598aa960e14492 100644 (file)
@@ -32,14 +32,7 @@ import dns.rdatatype
 import dns.resolver
 import dns.tsig
 import dns.tsigkeyring
-
-# Some tests require the internet to be available to run, so let's
-# skip those if it's not there.
-_network_available = True
-try:
-    socket.gethostbyname("dnspython.org")
-except socket.gaierror:
-    _network_available = False
+import tests.util
 
 # Some tests use a "nano nameserver" for testing.  It requires trio
 # and threading, so try to import it and if it doesn't work, skip
@@ -628,7 +621,7 @@ class BaseResolverTests(unittest.TestCase):
 keyname = dns.name.from_text("keyname")
 
 
-@unittest.skipIf(not _network_available, "Internet not reachable")
+@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
 class LiveResolverTests(unittest.TestCase):
     def testZoneForName1(self):
         name = dns.name.from_text("www.dnspython.org.")
@@ -990,7 +983,7 @@ class NaptrNanoNameserver(Server):
 
 
 @unittest.skipIf(
-    not (_network_available and _nanonameserver_available),
+    not (tests.util.is_internet_reachable() and _nanonameserver_available),
     "Internet and NanoAuth required",
 )
 class NanoTests(unittest.TestCase):
@@ -1057,7 +1050,7 @@ class AlwaysNoErrorNoDataNanoNameserver(Server):
 
 
 @unittest.skipIf(
-    not (_network_available and _nanonameserver_available),
+    not (tests.util.is_internet_reachable() and _nanonameserver_available),
     "Internet and NanoAuth required",
 )
 class ZoneForNameTests(unittest.TestCase):
@@ -1109,7 +1102,7 @@ class FormErrNanoNameserver(Server):
 
 
 @pytest.mark.skipif(
-    not (_network_available and _nanonameserver_available),
+    not (tests.util.is_internet_reachable() and _nanonameserver_available),
     reason="Internet and NanoAuth required",
 )
 def testResolverTimeout():
@@ -1136,7 +1129,7 @@ def testResolverTimeout():
 
 
 @pytest.mark.skipif(
-    not (_network_available and _nanonameserver_available),
+    not (tests.util.is_internet_reachable() and _nanonameserver_available),
     reason="Internet and NanoAuth required",
 )
 def testResolverNoNameservers():
@@ -1167,7 +1160,7 @@ class SlowAlwaysType3NXDOMAINNanoNameserver(Server):
 
 
 @pytest.mark.skipif(
-    not (_network_available and _nanonameserver_available),
+    not (tests.util.is_internet_reachable() and _nanonameserver_available),
     reason="Internet and NanoAuth required",
 )
 def testZoneForNameLifetimeTimeout():
index 3d79445db2ac3a833d7ef398761daa7c97f9f405..aed7a53d10726e2bae1d0e5fbb931bb26fd0bf97 100644 (file)
@@ -8,17 +8,10 @@ import dns.name
 import dns.rdataclass
 import dns.rdatatype
 import dns.resolver
+import tests.util
 
-# Some tests require the internet to be available to run, so let's
-# skip those if it's not there.
-_network_available = True
-try:
-    socket.gethostbyname("dnspython.org")
-except socket.gaierror:
-    _network_available = False
 
-
-@unittest.skipIf(not _network_available, "Internet not reachable")
+@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
 class OverrideSystemResolverTestCase(unittest.TestCase):
     def setUp(self):
         self.res = dns.resolver.Resolver(configure=False)
@@ -242,7 +235,7 @@ class OverrideSystemResolverUsingFakeResolverTestCase(unittest.TestCase):
             socket.gethostbyaddr("bogus")
 
 
-@unittest.skipIf(not _network_available, "Internet not reachable")
+@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
 class OverrideSystemResolverUsingDefaultResolverTestCase(unittest.TestCase):
     def setUp(self):
         self.res = FakeResolver()
index df9ab444625746ab3444bdfd1d01e58eed5a5e4a..c8f9704b2f50eb16cdfc36ec5e9358d0053bd8c3 100644 (file)
 import enum
 import inspect
 import os.path
+import socket
+
+# Cache for is_internet_reachable()
+_internet_reachable = None
 
 
 def here(filename):
     return os.path.join(os.path.dirname(__file__), filename)
 
 
+def is_internet_reachable():
+    """Check if the Internet is reachable.
+
+    The result is cached.
+    """
+    global _internet_reachable
+    if _internet_reachable is None:
+        try:
+            socket.gethostbyname("dnspython.org")
+            _internet_reachable = True
+        except socket.gaierror:
+            _internet_reachable = False
+    return _internet_reachable
+
+
 def enumerate_module(module, super_class):
     """Yield module attributes which are subclasses of given class"""
     for attr_name in dir(module):