]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Ensure asyncio datagram sockets on windows have had a bind() before
authorBob Halley <halley@dnspython.org>
Fri, 15 Dec 2023 02:04:39 +0000 (18:04 -0800)
committerBob Halley <halley@dnspython.org>
Fri, 15 Dec 2023 02:04:39 +0000 (18:04 -0800)
recvfrom().

The fix for [#637] erroneously concluded that that windows asyncio
needed connected datagram sockets, but subsequent further
investation showed that the actual problem was that windows wants
an unconnected datagram socket to be bound before recvfrom is called.
Linux autobinds in this case to the wildcard address and port, so
that's why we didn't see any problems there.  We now ensure that
the source is bound.

dns/_asyncio_backend.py
tests/test_async.py

index 2631228ecdc95684f1d30980780f3300bf81de9b..7d4d1b54cbc4a90edc9c668ef547696441d8ad67 100644 (file)
@@ -8,6 +8,7 @@ import sys
 
 import dns._asyncbackend
 import dns.exception
+import dns.inet
 
 _is_win32 = sys.platform == "win32"
 
@@ -224,14 +225,12 @@ class Backend(dns._asyncbackend.Backend):
         ssl_context=None,
         server_hostname=None,
     ):
-        if destination is None and socktype == socket.SOCK_DGRAM and _is_win32:
-            raise NotImplementedError(
-                "destinationless datagram sockets "
-                "are not supported by asyncio "
-                "on Windows"
-            )
         loop = _get_running_loop()
         if socktype == socket.SOCK_DGRAM:
+            if _is_win32 and source is None:
+                # Win32 wants explicit binding before recvfrom().  This is the
+                # proper fix for [#637].
+                source = (dns.inet.any_for_af(af), 0)
             transport, protocol = await loop.create_datagram_endpoint(
                 _DatagramProtocol,
                 source,
@@ -266,7 +265,7 @@ class Backend(dns._asyncbackend.Backend):
         await asyncio.sleep(interval)
 
     def datagram_connection_required(self):
-        return _is_win32
+        return False
 
     def get_transport_class(self):
         return _HTTPTransport
index d0f977a2147e7acb91557c8c11fceae7d116b2fd..ac32431c28907ada331df27b4dd143edb96bf871 100644 (file)
@@ -171,8 +171,6 @@ class MiscQuery(unittest.TestCase):
 
 @unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
 class AsyncTests(unittest.TestCase):
-    connect_udp = sys.platform == "win32"
-
     def setUp(self):
         self.backend = dns.asyncbackend.set_default_backend("asyncio")
 
@@ -327,12 +325,12 @@ class AsyncTests(unittest.TestCase):
             qname = dns.name.from_text("dns.google.")
 
             async def run():
-                if self.connect_udp:
-                    dtuple = (address, 53)
-                else:
-                    dtuple = None
                 async with await self.backend.make_socket(
-                    dns.inet.af_for_address(address), socket.SOCK_DGRAM, 0, None, dtuple
+                    dns.inet.af_for_address(address),
+                    socket.SOCK_DGRAM,
+                    0,
+                    None,
+                    None,
                 ) as s:
                     q = dns.message.make_query(qname, dns.rdatatype.A)
                     return await dns.asyncquery.udp(q, address, sock=s, timeout=2)
@@ -485,9 +483,6 @@ class AsyncTests(unittest.TestCase):
             self.assertFalse(tcp)
 
     def testUDPReceiveQuery(self):
-        if self.connect_udp:
-            self.skipTest("test needs connectionless sockets")
-
         async def run():
             async with await self.backend.make_socket(
                 socket.AF_INET, socket.SOCK_DGRAM, source=("127.0.0.1", 0)
@@ -509,9 +504,6 @@ class AsyncTests(unittest.TestCase):
         self.assertEqual(sender_address, recv_address)
 
     def testUDPReceiveTimeout(self):
-        if self.connect_udp:
-            self.skipTest("test needs connectionless sockets")
-
         async def arun():
             async with await self.backend.make_socket(
                 socket.AF_INET, socket.SOCK_DGRAM, 0, ("127.0.0.1", 0)
@@ -616,8 +608,6 @@ class AsyncTests(unittest.TestCase):
 
 @unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
 class AsyncioOnlyTests(unittest.TestCase):
-    connect_udp = sys.platform == "win32"
-
     def setUp(self):
         self.backend = dns.asyncbackend.set_default_backend("asyncio")
 
@@ -625,9 +615,6 @@ class AsyncioOnlyTests(unittest.TestCase):
         return asyncio.run(afunc())
 
     def testUseAfterTimeout(self):
-        if self.connect_udp:
-            self.skipTest("test needs connectionless sockets")
-
         # Test #843 fix.
         async def run():
             qname = dns.name.from_text("dns.google")
@@ -678,8 +665,6 @@ try:
             return trio.run(afunc)
 
     class TrioAsyncTests(AsyncTests):
-        connect_udp = False
-
         def setUp(self):
             self.backend = dns.asyncbackend.set_default_backend("trio")