]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
make get_backend() shorter; improve sniffing; fail if we cannot tell the library
authorBob Halley <halley@dnspython.org>
Fri, 12 Jun 2020 20:46:26 +0000 (13:46 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 12 Jun 2020 20:46:26 +0000 (13:46 -0700)
dns/asyncbackend.py

index 23256fedd76693a7ecfec7905e8787766ed19a38..acef9a67a508ef9ae43432a0bb9ff8cd3e8df6a1 100644 (file)
@@ -1,14 +1,18 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
+import dns.exception
 
 from dns._asyncbackend import Socket, DatagramSocket, \
     StreamSocket, Backend, low_level_address_tuple  # noqa:
 
 
 _default_backend = None
-_trio_backend = None
-_curio_backend = None
-_asyncio_backend = None
+
+_backends = {}
+
+
+class AsyncLibraryNotFoundError(dns.exception.DNSException):
+    pass
 
 
 def get_backend(name):
@@ -19,30 +23,22 @@ def get_backend(name):
 
     Raises NotImplementError if an unknown backend name is specified.
     """
+    backend = _backends.get(name)
+    if backend:
+        return backend
     if name == 'trio':
-        global _trio_backend
-        if _trio_backend:
-            return _trio_backend
         import dns._trio_backend
-        _trio_backend = dns._trio_backend.Backend()
-        return _trio_backend
+        backend = dns._trio_backend.Backend()
     elif name == 'curio':
-        global _curio_backend
-        if _curio_backend:
-            return _curio_backend
         import dns._curio_backend
-        _curio_backend = dns._curio_backend.Backend()
-        return _curio_backend
+        backend = dns._curio_backend.Backend()
     elif name == 'asyncio':
-        global _asyncio_backend
-        if _asyncio_backend:
-            return _asyncio_backend
         import dns._asyncio_backend
-        _asyncio_backend = dns._asyncio_backend.Backend()
-        return _asyncio_backend
+        backend = dns._asyncio_backend.Backend()
     else:
         raise NotImplementedError(f'unimplemented async backend {name}')
-
+    _backends[name] = backend
+    return backend
 
 def sniff():
     """Attempt to determine the in-use asynchronous I/O library by using
@@ -51,14 +47,24 @@ def sniff():
     Returns the name of the library, defaulting to "asyncio" if no other
     library appears to be in use.
     """
-    name = 'asyncio'
     try:
         import sniffio
-        name = sniffio.current_async_library()
-    except Exception:
-        pass
-    return name
-
+        try:
+            return sniffio.current_async_library()
+        except sniffio.AsyncLibraryNotFoundError:
+            raise AsyncLibraryNotFoundError('sniffio cannot determine ' +
+                                            'async library')
+    except ImportError:
+        import asyncio
+        try:
+            asyncio.get_running_loop()
+            return 'asyncio'
+        except RuntimeError:
+            raise AsyncLibraryNotFoundError('no async library detected')
+        except AttributeError:
+            # we have to check current_task on 3.6
+            if not asyncio.Task.current_task():
+                raise AsyncLibraryNotFoundError('no async library detected')
 
 def get_default_backend():
     """Get the default backend, initializing it if necessary.