]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
improve async coverage
authorBob Halley <halley@dnspython.org>
Wed, 17 Jun 2020 14:46:36 +0000 (07:46 -0700)
committerBob Halley <halley@dnspython.org>
Wed, 17 Jun 2020 14:46:36 +0000 (07:46 -0700)
dns/asyncbackend.py
dns/asyncresolver.py
tests/test_async.py

index 26f239769d2d7e919cf41240b6dd5984adbe1576..f028417c971474b6f10286753c89e0e487325116 100644 (file)
@@ -10,6 +10,8 @@ _default_backend = None
 
 _backends = {}
 
+# Allow sniffio import to be disabled for testing purposes
+_no_sniffio = False
 
 class AsyncLibraryNotFoundError(dns.exception.DNSException):
     pass
@@ -40,6 +42,7 @@ def get_backend(name):
     _backends[name] = backend
     return backend
 
+
 def sniff():
     """Attempt to determine the in-use asynchronous I/O library by using
     the ``sniffio`` module if it is available.
@@ -48,6 +51,8 @@ def sniff():
     if the library cannot be determined.
     """
     try:
+        if _no_sniffio:
+            raise ImportError
         import sniffio
         try:
             return sniffio.current_async_library()
@@ -61,11 +66,12 @@ def sniff():
             return 'asyncio'
         except RuntimeError:
             raise AsyncLibraryNotFoundError('no async library detected')
-        except AttributeError:
+        except AttributeError:  # pragma: no cover
             # we have to check current_task on 3.6
             if not asyncio.Task.current_task():
                 raise AsyncLibraryNotFoundError('no async library detected')
 
+
 def get_default_backend():
     """Get the default backend, initializing it if necessary.
     """
index 0909860481c2ffea3ca24d7bd9e94d7c683e29be..3ac334f505c42999aceb7eddea8fae03c28d70c3 100644 (file)
@@ -253,5 +253,5 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
             pass
         try:
             name = name.parent()
-        except dns.name.NoParent:
+        except dns.name.NoParent:  # pragma: no cover
             raise NoRootSOA
index 15c7eaa1ec789144123789ce51a7bdb898373c2a..42f4c69abad6504d3e7ec9324c289b2fcef59136 100644 (file)
@@ -61,6 +61,94 @@ for (af, address) in ((socket.AF_INET, '8.8.8.8'),
         pass
 
 
+class AsyncDetectionTests(unittest.TestCase):
+    sniff_result = 'asyncio'
+
+    def async_run(self, afunc):
+        try:
+            runner = asyncio.run
+        except AttributeError:
+            # this is only needed for 3.6
+            def old_runner(awaitable):
+                loop = asyncio.get_event_loop()
+                return loop.run_until_complete(awaitable)
+            runner = old_runner
+        return runner(afunc())
+
+    def test_sniff(self):
+        dns.asyncbackend._default_backend = None
+        async def run():
+            self.assertEqual(dns.asyncbackend.sniff(), self.sniff_result)
+        self.async_run(run)
+
+    def test_get_default_backend(self):
+        dns.asyncbackend._default_backend = None
+        async def run():
+            backend = dns.asyncbackend.get_default_backend()
+            self.assertEqual(backend.name(), self.sniff_result)
+        self.async_run(run)
+
+class NoSniffioAsyncDetectionTests(AsyncDetectionTests):
+    expect_raise = False
+
+    def setUp(self):
+        dns.asyncbackend._no_sniffio = True
+
+    def tearDown(self):
+        dns.asyncbackend._no_sniffio = False
+
+    def test_sniff(self):
+        dns.asyncbackend._default_backend = None
+        if self.expect_raise:
+            async def abad():
+                dns.asyncbackend.sniff()
+            def bad():
+                self.async_run(abad)
+            self.assertRaises(dns.asyncbackend.AsyncLibraryNotFoundError, bad)
+        else:
+            super().test_sniff()
+
+    def test_get_default_backend(self):
+        dns.asyncbackend._default_backend = None
+        if self.expect_raise:
+            async def abad():
+                dns.asyncbackend.get_default_backend()
+            def bad():
+                self.async_run(abad)
+            self.assertRaises(dns.asyncbackend.AsyncLibraryNotFoundError, bad)
+        else:
+            super().test_get_default_backend()
+
+
+class MiscBackend(unittest.TestCase):
+    def test_sniff_without_run_loop(self):
+        dns.asyncbackend._default_backend = None
+        def bad():
+            dns.asyncbackend.sniff()
+        self.assertRaises(dns.asyncbackend.AsyncLibraryNotFoundError, bad)
+
+    def test_bogus_backend(self):
+        def bad():
+            dns.asyncbackend.get_backend('bogus')
+        self.assertRaises(NotImplementedError, bad)
+
+
+class MiscQuery(unittest.TestCase):
+    def test_source_tuple(self):
+        t = dns.asyncquery._source_tuple(socket.AF_INET, None, 0)
+        self.assertEqual(t, None)
+        t = dns.asyncquery._source_tuple(socket.AF_INET6, None, 0)
+        self.assertEqual(t, None)
+        t = dns.asyncquery._source_tuple(socket.AF_INET, '1.2.3.4', 53)
+        self.assertEqual(t, ('1.2.3.4', 53))
+        t = dns.asyncquery._source_tuple(socket.AF_INET6, '1::2', 53)
+        self.assertEqual(t, ('1::2', 53))
+        t = dns.asyncquery._source_tuple(socket.AF_INET, None, 53)
+        self.assertEqual(t, ('0.0.0.0', 53))
+        t = dns.asyncquery._source_tuple(socket.AF_INET6, None, 53)
+        self.assertEqual(t, ('::', 53))
+
+
 @unittest.skipIf(not _network_available, "Internet not reachable")
 class AsyncTests(unittest.TestCase):
 
@@ -93,6 +181,15 @@ class AsyncTests(unittest.TestCase):
         dnsgoogle = dns.name.from_text('dns.google.')
         self.assertEqual(answer[0].target, dnsgoogle)
 
+    def testResolverBadScheme(self):
+        res = dns.asyncresolver.Resolver()
+        res.nameservers = ['bogus://dns.google/dns-query']
+        async def run():
+            answer = await res.resolve('dns.google', 'A')
+        def bad():
+            self.async_run(run)
+        self.assertRaises(dns.resolver.NoNameservers, bad)
+
     def testZoneForName1(self):
         async def run():
             name = dns.name.from_text('www.dnspython.org.')
@@ -247,6 +344,17 @@ class AsyncTests(unittest.TestCase):
 
 try:
     import trio
+    import sniffio
+
+    class TrioAsyncDetectionTests(AsyncDetectionTests):
+        sniff_result = 'trio'
+        def async_run(self, afunc):
+            return trio.run(afunc)
+
+    class TrioNoSniffioAsyncDetectionTests(NoSniffioAsyncDetectionTests):
+        expect_raise = True
+        def async_run(self, afunc):
+            return trio.run(afunc)
 
     class TrioAsyncTests(AsyncTests):
         def setUp(self):
@@ -259,6 +367,17 @@ except ImportError:
 
 try:
     import curio
+    import sniffio
+
+    class CurioAsyncDetectionTests(AsyncDetectionTests):
+        sniff_result = 'curio'
+        def async_run(self, afunc):
+            return curio.run(afunc)
+
+    class CurioNoSniffioAsyncDetectionTests(NoSniffioAsyncDetectionTests):
+        expect_raise = True
+        def async_run(self, afunc):
+            return curio.run(afunc)
 
     class CurioAsyncTests(AsyncTests):
         def setUp(self):