]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
asyncio on Windows requries connected sockets. [Issue #637] 645/head
authorBob Halley <halley@dnspython.org>
Thu, 25 Feb 2021 16:55:15 +0000 (08:55 -0800)
committerBob Halley <halley@dnspython.org>
Thu, 25 Feb 2021 16:55:15 +0000 (08:55 -0800)
dns/_asyncbackend.py
dns/_asyncio_backend.py
dns/asyncquery.py
tests/test_async.py

index 0ce316b231d258821ef88e8cda903cbb92a45882..69411df6fd18bd09ee209b15c692f318b3010fc6 100644 (file)
@@ -64,3 +64,6 @@ class Backend:    # pragma: no cover
                           source=None, destination=None, timeout=None,
                           ssl_context=None, server_hostname=None):
         raise NotImplementedError
+
+    def datagram_connection_required(self):
+        return False
index 17bd0f788419d6215225503e144a0fc3e9ed1695..80c31dcd19e920745a7ea6028455072d6e097445 100644 (file)
@@ -4,11 +4,14 @@
 
 import socket
 import asyncio
+import sys
 
 import dns._asyncbackend
 import dns.exception
 
 
+_is_win32 = sys.platform == 'win32'
+
 def _get_running_loop():
     try:
         return asyncio.get_running_loop()
@@ -114,11 +117,16 @@ class Backend(dns._asyncbackend.Backend):
     async def make_socket(self, af, socktype, proto=0,
                           source=None, destination=None, timeout=None,
                           ssl_context=None, server_hostname=None):
+        if destination is None and socktype == socket.SOCK_DGRAM and \
+           _is_win32:
+            raise NotImplementedError('destinationless datagram sockets '
+                                      'are not supported by asyncio '
+                                      'on Windows')
         loop = _get_running_loop()
         if socktype == socket.SOCK_DGRAM:
             transport, protocol = await loop.create_datagram_endpoint(
                 _DatagramProtocol, source, family=af,
-                proto=proto)
+                proto=proto, remote_addr=destination)
             return DatagramSocket(af, transport, protocol)
         elif socktype == socket.SOCK_STREAM:
             (r, w) = await _maybe_wait_for(
@@ -136,3 +144,7 @@ class Backend(dns._asyncbackend.Backend):
 
     async def sleep(self, interval):
         await asyncio.sleep(interval)
+
+    def datagram_connection_required(self):
+        return _is_win32
+        
index 89c2622fee643a8b874838abeac207e03c937a2c..0e353e8ade1d384059ce4cf077b759d6c1443bba 100644 (file)
@@ -142,7 +142,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
             if not backend:
                 backend = dns.asyncbackend.get_default_backend()
             stuple = _source_tuple(af, source, source_port)
-            s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple)
+            if backend.datagram_connection_required():
+                dtuple = (where, port)
+            else:
+                dtuple = None
+            s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple,
+                                          dtuple)
         await send_udp(s, wire, destination, expiration)
         (r, received_time, _) = await receive_udp(s, destination, expiration,
                                                   ignore_unexpected,
index e9a26bb387d859c0cbcfa9b958af9d34d335fad2..0252f2255d2786f81e4704c641f726b31cb2ed85 100644 (file)
@@ -17,6 +17,7 @@
 
 import asyncio
 import socket
+import sys
 import time
 import unittest
 
@@ -152,6 +153,7 @@ class MiscQuery(unittest.TestCase):
 
 @unittest.skipIf(not _network_available, "Internet not reachable")
 class AsyncTests(unittest.TestCase):
+    connect_udp = sys.platform == 'win32'
 
     def setUp(self):
         self.backend = dns.asyncbackend.set_default_backend('asyncio')
@@ -261,9 +263,13 @@ class AsyncTests(unittest.TestCase):
         for address in query_addresses:
             qname = dns.name.from_text('dns.google.')
             async def run():
+                if self.connect_udp:
+                    dtuple=(address, 53)
+                else:
+                    dtuple=None
                 async with await self.backend.make_socket(
                         dns.inet.af_for_address(address),
-                        socket.SOCK_DGRAM) as s:
+                        socket.SOCK_DGRAM, 0, None, dtuple) as s:
                     q = dns.message.make_query(qname, dns.rdatatype.A)
                     return await dns.asyncquery.udp(q, address, sock=s,
                                                     timeout=2)
@@ -373,6 +379,8 @@ class AsyncTests(unittest.TestCase):
             self.assertFalse(tcp)
 
     def testUDPReceiveQuery(self):
+        if self.connect_udp:
+            self.skipTest('test needs connectionless sockets')
         async def run():
             async with await self.backend.make_socket(
                     socket.AF_INET, socket.SOCK_DGRAM,
@@ -392,6 +400,8 @@ class AsyncTests(unittest.TestCase):
         self.assertEqual(sender_address, recv_address)
 
     def testUDPReceiveTimeout(self):
+        if self.connect_udp:
+            self.skipTest('test needs connectionless sockets')
         async def arun():
             async with await self.backend.make_socket(socket.AF_INET,
                                                       socket.SOCK_DGRAM, 0,
@@ -430,6 +440,7 @@ try:
             return trio.run(afunc)
 
     class TrioAsyncTests(AsyncTests):
+        connect_udp = False
         def setUp(self):
             self.backend = dns.asyncbackend.set_default_backend('trio')
 
@@ -453,6 +464,7 @@ try:
             return curio.run(afunc)
 
     class CurioAsyncTests(AsyncTests):
+        connect_udp = False
         def setUp(self):
             self.backend = dns.asyncbackend.set_default_backend('curio')