From: Bob Halley Date: Wed, 17 Jun 2020 14:46:36 +0000 (-0700) Subject: improve async coverage X-Git-Tag: v2.0.0rc1~68 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cdbac65201b0fb8ce2623b3bb16b0ac6c6eb065b;p=thirdparty%2Fdnspython.git improve async coverage --- diff --git a/dns/asyncbackend.py b/dns/asyncbackend.py index 26f23976..f028417c 100644 --- a/dns/asyncbackend.py +++ b/dns/asyncbackend.py @@ -10,6 +10,8 @@ _default_backend = None _backends = {} +# Allow sniffio import to be disabled for testing purposes +_no_sniffio = False class AsyncLibraryNotFoundError(dns.exception.DNSException): pass @@ -40,6 +42,7 @@ def get_backend(name): _backends[name] = backend return backend + def sniff(): """Attempt to determine the in-use asynchronous I/O library by using the ``sniffio`` module if it is available. @@ -48,6 +51,8 @@ def sniff(): if the library cannot be determined. """ try: + if _no_sniffio: + raise ImportError import sniffio try: return sniffio.current_async_library() @@ -61,11 +66,12 @@ def sniff(): return 'asyncio' except RuntimeError: raise AsyncLibraryNotFoundError('no async library detected') - except AttributeError: + except AttributeError: # pragma: no cover # we have to check current_task on 3.6 if not asyncio.Task.current_task(): raise AsyncLibraryNotFoundError('no async library detected') + def get_default_backend(): """Get the default backend, initializing it if necessary. """ diff --git a/dns/asyncresolver.py b/dns/asyncresolver.py index 09098604..3ac334f5 100644 --- a/dns/asyncresolver.py +++ b/dns/asyncresolver.py @@ -253,5 +253,5 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, pass try: name = name.parent() - except dns.name.NoParent: + except dns.name.NoParent: # pragma: no cover raise NoRootSOA diff --git a/tests/test_async.py b/tests/test_async.py index 15c7eaa1..42f4c69a 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -61,6 +61,94 @@ for (af, address) in ((socket.AF_INET, '8.8.8.8'), pass +class AsyncDetectionTests(unittest.TestCase): + sniff_result = 'asyncio' + + def async_run(self, afunc): + try: + runner = asyncio.run + except AttributeError: + # this is only needed for 3.6 + def old_runner(awaitable): + loop = asyncio.get_event_loop() + return loop.run_until_complete(awaitable) + runner = old_runner + return runner(afunc()) + + def test_sniff(self): + dns.asyncbackend._default_backend = None + async def run(): + self.assertEqual(dns.asyncbackend.sniff(), self.sniff_result) + self.async_run(run) + + def test_get_default_backend(self): + dns.asyncbackend._default_backend = None + async def run(): + backend = dns.asyncbackend.get_default_backend() + self.assertEqual(backend.name(), self.sniff_result) + self.async_run(run) + +class NoSniffioAsyncDetectionTests(AsyncDetectionTests): + expect_raise = False + + def setUp(self): + dns.asyncbackend._no_sniffio = True + + def tearDown(self): + dns.asyncbackend._no_sniffio = False + + def test_sniff(self): + dns.asyncbackend._default_backend = None + if self.expect_raise: + async def abad(): + dns.asyncbackend.sniff() + def bad(): + self.async_run(abad) + self.assertRaises(dns.asyncbackend.AsyncLibraryNotFoundError, bad) + else: + super().test_sniff() + + def test_get_default_backend(self): + dns.asyncbackend._default_backend = None + if self.expect_raise: + async def abad(): + dns.asyncbackend.get_default_backend() + def bad(): + self.async_run(abad) + self.assertRaises(dns.asyncbackend.AsyncLibraryNotFoundError, bad) + else: + super().test_get_default_backend() + + +class MiscBackend(unittest.TestCase): + def test_sniff_without_run_loop(self): + dns.asyncbackend._default_backend = None + def bad(): + dns.asyncbackend.sniff() + self.assertRaises(dns.asyncbackend.AsyncLibraryNotFoundError, bad) + + def test_bogus_backend(self): + def bad(): + dns.asyncbackend.get_backend('bogus') + self.assertRaises(NotImplementedError, bad) + + +class MiscQuery(unittest.TestCase): + def test_source_tuple(self): + t = dns.asyncquery._source_tuple(socket.AF_INET, None, 0) + self.assertEqual(t, None) + t = dns.asyncquery._source_tuple(socket.AF_INET6, None, 0) + self.assertEqual(t, None) + t = dns.asyncquery._source_tuple(socket.AF_INET, '1.2.3.4', 53) + self.assertEqual(t, ('1.2.3.4', 53)) + t = dns.asyncquery._source_tuple(socket.AF_INET6, '1::2', 53) + self.assertEqual(t, ('1::2', 53)) + t = dns.asyncquery._source_tuple(socket.AF_INET, None, 53) + self.assertEqual(t, ('0.0.0.0', 53)) + t = dns.asyncquery._source_tuple(socket.AF_INET6, None, 53) + self.assertEqual(t, ('::', 53)) + + @unittest.skipIf(not _network_available, "Internet not reachable") class AsyncTests(unittest.TestCase): @@ -93,6 +181,15 @@ class AsyncTests(unittest.TestCase): dnsgoogle = dns.name.from_text('dns.google.') self.assertEqual(answer[0].target, dnsgoogle) + def testResolverBadScheme(self): + res = dns.asyncresolver.Resolver() + res.nameservers = ['bogus://dns.google/dns-query'] + async def run(): + answer = await res.resolve('dns.google', 'A') + def bad(): + self.async_run(run) + self.assertRaises(dns.resolver.NoNameservers, bad) + def testZoneForName1(self): async def run(): name = dns.name.from_text('www.dnspython.org.') @@ -247,6 +344,17 @@ class AsyncTests(unittest.TestCase): try: import trio + import sniffio + + class TrioAsyncDetectionTests(AsyncDetectionTests): + sniff_result = 'trio' + def async_run(self, afunc): + return trio.run(afunc) + + class TrioNoSniffioAsyncDetectionTests(NoSniffioAsyncDetectionTests): + expect_raise = True + def async_run(self, afunc): + return trio.run(afunc) class TrioAsyncTests(AsyncTests): def setUp(self): @@ -259,6 +367,17 @@ except ImportError: try: import curio + import sniffio + + class CurioAsyncDetectionTests(AsyncDetectionTests): + sniff_result = 'curio' + def async_run(self, afunc): + return curio.run(afunc) + + class CurioNoSniffioAsyncDetectionTests(NoSniffioAsyncDetectionTests): + expect_raise = True + def async_run(self, afunc): + return curio.run(afunc) class CurioAsyncTests(AsyncTests): def setUp(self):