From: Bob Halley Date: Fri, 12 Jun 2020 20:46:26 +0000 (-0700) Subject: make get_backend() shorter; improve sniffing; fail if we cannot tell the library X-Git-Tag: v2.0.0rc1~112^2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=41a111e852a4ea63727ba51abd668c9fec0fd599;p=thirdparty%2Fdnspython.git make get_backend() shorter; improve sniffing; fail if we cannot tell the library --- diff --git a/dns/asyncbackend.py b/dns/asyncbackend.py index 23256fed..acef9a67 100644 --- a/dns/asyncbackend.py +++ b/dns/asyncbackend.py @@ -1,14 +1,18 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +import dns.exception from dns._asyncbackend import Socket, DatagramSocket, \ StreamSocket, Backend, low_level_address_tuple # noqa: _default_backend = None -_trio_backend = None -_curio_backend = None -_asyncio_backend = None + +_backends = {} + + +class AsyncLibraryNotFoundError(dns.exception.DNSException): + pass def get_backend(name): @@ -19,30 +23,22 @@ def get_backend(name): Raises NotImplementError if an unknown backend name is specified. """ + backend = _backends.get(name) + if backend: + return backend if name == 'trio': - global _trio_backend - if _trio_backend: - return _trio_backend import dns._trio_backend - _trio_backend = dns._trio_backend.Backend() - return _trio_backend + backend = dns._trio_backend.Backend() elif name == 'curio': - global _curio_backend - if _curio_backend: - return _curio_backend import dns._curio_backend - _curio_backend = dns._curio_backend.Backend() - return _curio_backend + backend = dns._curio_backend.Backend() elif name == 'asyncio': - global _asyncio_backend - if _asyncio_backend: - return _asyncio_backend import dns._asyncio_backend - _asyncio_backend = dns._asyncio_backend.Backend() - return _asyncio_backend + backend = dns._asyncio_backend.Backend() else: raise NotImplementedError(f'unimplemented async backend {name}') - + _backends[name] = backend + return backend def sniff(): """Attempt to determine the in-use asynchronous I/O library by using @@ -51,14 +47,24 @@ def sniff(): Returns the name of the library, defaulting to "asyncio" if no other library appears to be in use. """ - name = 'asyncio' try: import sniffio - name = sniffio.current_async_library() - except Exception: - pass - return name - + try: + return sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + raise AsyncLibraryNotFoundError('sniffio cannot determine ' + + 'async library') + except ImportError: + import asyncio + try: + asyncio.get_running_loop() + return 'asyncio' + except RuntimeError: + raise AsyncLibraryNotFoundError('no async library detected') + except AttributeError: + # we have to check current_task on 3.6 + if not asyncio.Task.current_task(): + raise AsyncLibraryNotFoundError('no async library detected') def get_default_backend(): """Get the default backend, initializing it if necessary.