From 41a111e852a4ea63727ba51abd668c9fec0fd599 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Fri, 12 Jun 2020 13:46:26 -0700 Subject: [PATCH] make get_backend() shorter; improve sniffing; fail if we cannot tell the library --- dns/asyncbackend.py | 56 +++++++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/dns/asyncbackend.py b/dns/asyncbackend.py index 23256fed..acef9a67 100644 --- a/dns/asyncbackend.py +++ b/dns/asyncbackend.py @@ -1,14 +1,18 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +import dns.exception from dns._asyncbackend import Socket, DatagramSocket, \ StreamSocket, Backend, low_level_address_tuple # noqa: _default_backend = None -_trio_backend = None -_curio_backend = None -_asyncio_backend = None + +_backends = {} + + +class AsyncLibraryNotFoundError(dns.exception.DNSException): + pass def get_backend(name): @@ -19,30 +23,22 @@ def get_backend(name): Raises NotImplementError if an unknown backend name is specified. """ + backend = _backends.get(name) + if backend: + return backend if name == 'trio': - global _trio_backend - if _trio_backend: - return _trio_backend import dns._trio_backend - _trio_backend = dns._trio_backend.Backend() - return _trio_backend + backend = dns._trio_backend.Backend() elif name == 'curio': - global _curio_backend - if _curio_backend: - return _curio_backend import dns._curio_backend - _curio_backend = dns._curio_backend.Backend() - return _curio_backend + backend = dns._curio_backend.Backend() elif name == 'asyncio': - global _asyncio_backend - if _asyncio_backend: - return _asyncio_backend import dns._asyncio_backend - _asyncio_backend = dns._asyncio_backend.Backend() - return _asyncio_backend + backend = dns._asyncio_backend.Backend() else: raise NotImplementedError(f'unimplemented async backend {name}') - + _backends[name] = backend + return backend def sniff(): """Attempt to determine the in-use asynchronous I/O library by using @@ -51,14 +47,24 @@ def sniff(): Returns the name of the library, defaulting to "asyncio" if no other library appears to be in use. """ - name = 'asyncio' try: import sniffio - name = sniffio.current_async_library() - except Exception: - pass - return name - + try: + return sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + raise AsyncLibraryNotFoundError('sniffio cannot determine ' + + 'async library') + except ImportError: + import asyncio + try: + asyncio.get_running_loop() + return 'asyncio' + except RuntimeError: + raise AsyncLibraryNotFoundError('no async library detected') + except AttributeError: + # we have to check current_task on 3.6 + if not asyncio.Task.current_task(): + raise AsyncLibraryNotFoundError('no async library detected') def get_default_backend(): """Get the default backend, initializing it if necessary. -- 2.47.3