]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Attempt to cope with python 3.6 asyncio.
authorBob Halley <halley@dnspython.org>
Fri, 12 Jun 2020 03:01:23 +0000 (20:01 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 12 Jun 2020 03:01:23 +0000 (20:01 -0700)
dns/_asyncio_backend.py
tests/test_async.py

index 42c6e662dff314440ad5ee1e55a0841d729a8b84..07e9e5e252121217791e00313ee2116b575296ba 100644 (file)
@@ -8,6 +8,14 @@ import asyncio
 import dns._asyncbackend
 import dns.exception
 
+
+def _get_running_loop():
+    try:
+        return asyncio.get_running_loop()
+    except AttributeError:
+        return asyncio.get_event_loop()
+
+
 class _DatagramProtocol:
     def __init__(self):
         self.transport = None
@@ -42,6 +50,7 @@ async def _maybe_wait_for(awaitable, timeout):
     else:
         return await awaitable
 
+
 class DatagramSocket(dns._asyncbackend.DatagramSocket):
     def __init__(self, family, transport, protocol):
         self.family = family
@@ -53,7 +62,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
         self.transport.sendto(what, destination)
 
     async def recvfrom(self, timeout):
-        done = asyncio.get_running_loop().create_future()
+        done = _get_running_loop().create_future()
         assert self.protocol.recvfrom is None
         self.protocol.recvfrom = done
         await _maybe_wait_for(done, timeout)
@@ -84,7 +93,10 @@ class StreamSocket(dns._asyncbackend.DatagramSocket):
 
     async def close(self):
         self.writer.close()
-        await self.writer.wait_closed()
+        try:
+            await self.writer.wait_closed()
+        except AttributeError:
+            pass
 
     async def getpeername(self):
         return self.reader.get_extra_info('peername')
@@ -97,7 +109,7 @@ class Backend(dns._asyncbackend.Backend):
     async def make_socket(self, af, socktype, proto=0,
                           source=None, destination=None, timeout=None,
                           ssl_context=None, server_hostname=None):
-        loop = asyncio.get_running_loop()
+        loop = _get_running_loop()
         if socktype == socket.SOCK_DGRAM:
             transport, protocol = await loop.create_datagram_endpoint(
                 _DatagramProtocol, source, family=af,
index c09941b99727ec1d13ae2f1cc5e2905f7b516a82..ef07bb14a8529ed92d62984c17cf371d76de5673 100644 (file)
@@ -43,7 +43,14 @@ class AsyncTests(unittest.TestCase):
         self.backend = dns.asyncbackend.set_default_backend('asyncio')
 
     def async_run(self, afunc):
-        return asyncio.run(afunc())
+        try:
+            runner = asyncio.run
+        except AttributeError:
+            def old_runner(awaitable):
+                loop = asyncio.get_event_loop()
+                return loop.run_until_complete(awaitable)
+            runner = old_runner
+        return runner(afunc())
 
     def testResolve(self):
         async def run():