# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+import dns.exception
from dns._asyncbackend import Socket, DatagramSocket, \
StreamSocket, Backend, low_level_address_tuple # noqa:
_default_backend = None
-_trio_backend = None
-_curio_backend = None
-_asyncio_backend = None
+
+_backends = {}
+
+
+class AsyncLibraryNotFoundError(dns.exception.DNSException):
+ pass
def get_backend(name):
Raises NotImplementError if an unknown backend name is specified.
"""
+ backend = _backends.get(name)
+ if backend:
+ return backend
if name == 'trio':
- global _trio_backend
- if _trio_backend:
- return _trio_backend
import dns._trio_backend
- _trio_backend = dns._trio_backend.Backend()
- return _trio_backend
+ backend = dns._trio_backend.Backend()
elif name == 'curio':
- global _curio_backend
- if _curio_backend:
- return _curio_backend
import dns._curio_backend
- _curio_backend = dns._curio_backend.Backend()
- return _curio_backend
+ backend = dns._curio_backend.Backend()
elif name == 'asyncio':
- global _asyncio_backend
- if _asyncio_backend:
- return _asyncio_backend
import dns._asyncio_backend
- _asyncio_backend = dns._asyncio_backend.Backend()
- return _asyncio_backend
+ backend = dns._asyncio_backend.Backend()
else:
raise NotImplementedError(f'unimplemented async backend {name}')
-
+ _backends[name] = backend
+ return backend
def sniff():
"""Attempt to determine the in-use asynchronous I/O library by using
Returns the name of the library, defaulting to "asyncio" if no other
library appears to be in use.
"""
- name = 'asyncio'
try:
import sniffio
- name = sniffio.current_async_library()
- except Exception:
- pass
- return name
-
+ try:
+ return sniffio.current_async_library()
+ except sniffio.AsyncLibraryNotFoundError:
+ raise AsyncLibraryNotFoundError('sniffio cannot determine ' +
+ 'async library')
+ except ImportError:
+ import asyncio
+ try:
+ asyncio.get_running_loop()
+ return 'asyncio'
+ except RuntimeError:
+ raise AsyncLibraryNotFoundError('no async library detected')
+ except AttributeError:
+ # we have to check current_task on 3.6
+ if not asyncio.Task.current_task():
+ raise AsyncLibraryNotFoundError('no async library detected')
def get_default_backend():
"""Get the default backend, initializing it if necessary.