From: Bob Halley Date: Fri, 12 Jun 2020 03:01:23 +0000 (-0700) Subject: Attempt to cope with python 3.6 asyncio. X-Git-Tag: v2.0.0rc1~112^2~21 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6caf2090dd77879ebe662dbc7e2492f705851723;p=thirdparty%2Fdnspython.git Attempt to cope with python 3.6 asyncio. --- diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index 42c6e662..07e9e5e2 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -8,6 +8,14 @@ import asyncio import dns._asyncbackend import dns.exception + +def _get_running_loop(): + try: + return asyncio.get_running_loop() + except AttributeError: + return asyncio.get_event_loop() + + class _DatagramProtocol: def __init__(self): self.transport = None @@ -42,6 +50,7 @@ async def _maybe_wait_for(awaitable, timeout): else: return await awaitable + class DatagramSocket(dns._asyncbackend.DatagramSocket): def __init__(self, family, transport, protocol): self.family = family @@ -53,7 +62,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): self.transport.sendto(what, destination) async def recvfrom(self, timeout): - done = asyncio.get_running_loop().create_future() + done = _get_running_loop().create_future() assert self.protocol.recvfrom is None self.protocol.recvfrom = done await _maybe_wait_for(done, timeout) @@ -84,7 +93,10 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): async def close(self): self.writer.close() - await self.writer.wait_closed() + try: + await self.writer.wait_closed() + except AttributeError: + pass async def getpeername(self): return self.reader.get_extra_info('peername') @@ -97,7 +109,7 @@ class Backend(dns._asyncbackend.Backend): async def make_socket(self, af, socktype, proto=0, source=None, destination=None, timeout=None, ssl_context=None, server_hostname=None): - loop = asyncio.get_running_loop() + loop = _get_running_loop() if socktype == socket.SOCK_DGRAM: transport, protocol = await loop.create_datagram_endpoint( _DatagramProtocol, source, family=af, diff --git a/tests/test_async.py b/tests/test_async.py index c09941b9..ef07bb14 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -43,7 +43,14 @@ class AsyncTests(unittest.TestCase): self.backend = dns.asyncbackend.set_default_backend('asyncio') def async_run(self, afunc): - return asyncio.run(afunc()) + try: + runner = asyncio.run + except AttributeError: + def old_runner(awaitable): + loop = asyncio.get_event_loop() + return loop.run_until_complete(awaitable) + runner = old_runner + return runner(afunc()) def testResolve(self): async def run():