]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Fix DoQ for asyncio IPv6 [#958].
authorBob Halley <halley@dnspython.org>
Mon, 17 Jul 2023 22:24:59 +0000 (15:24 -0700)
committerBob Halley <halley@dnspython.org>
Mon, 17 Jul 2023 22:24:59 +0000 (15:24 -0700)
dns/quic/_asyncio.py
tests/nanoquic.py
tests/test_doq.py

index f01ebc331a1a656aff7ace3553e889ff19f6cf95..30de36ae45f9b9a86f69eb65fbe37329bcba6566 100644 (file)
@@ -88,8 +88,10 @@ class AsyncioQuicConnection(AsyncQuicConnection):
         try:
             af = dns.inet.af_for_address(self._address)
             backend = dns.asyncbackend.get_backend("asyncio")
+            # Note that peer is a low-level address tuple, but make_socket() wants
+            # a high-level address tuple, so we convert.
             self._socket = await backend.make_socket(
-                af, socket.SOCK_DGRAM, 0, self._source, self._peer
+                af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1])
             )
             self._socket_created.set()
             async with self._socket:
@@ -184,6 +186,9 @@ class AsyncioQuicConnection(AsyncQuicConnection):
             self._manager.closed(self._peer[0], self._peer[1])
             self._closed = True
             self._connection.close()
+            if not self._socket_created.is_set():
+                # sender might be blocked on this, so set it
+                self._socket_created.set()
             async with self._wake_timer:
                 self._wake_timer.notify_all()
             try:
index 7599fcc6b1a3942f421e14298cb6b9c06bdbf5d9..b4efef2fa4c8ef0af37df526519cc474f73bd0e1 100644 (file)
@@ -14,7 +14,6 @@ try:
     import dns.asyncquery
     import dns.message
     import dns.rcode
-
     from tests.util import here
 
     have_quic = True
@@ -79,8 +78,9 @@ try:
             return struct.pack("!H", len(wire)) + wire
 
     class Server(threading.Thread):
-        def __init__(self):
+        def __init__(self, address="127.0.0.1"):
             super().__init__()
+            self.address = address
             self.transport = None
             self.protocol = None
             self.left = None
@@ -91,6 +91,7 @@ try:
             self.left, self.right = socket.socketpair()
             self.start()
             self.ready.wait(4)
+            return self
 
         def __exit__(self, ex_ty, ex_va, ex_tr):
             if self.protocol is not None:
@@ -116,7 +117,7 @@ try:
                 lambda: aioquic.asyncio.server.QuicServer(
                     configuration=conf, create_protocol=NanoQuic
                 ),
-                local_addr=("127.0.0.1", 8853),
+                local_addr=(self.address, 8853),
             )
             self.ready.set()
             try:
index 7d48bc22c002acb0630b68baba89a4f7624225aa..c43c06545bb6383f80967ef3ede55ba6845fd72b 100644 (file)
@@ -11,7 +11,7 @@ import dns.message
 import dns.query
 import dns.rcode
 
-from .util import here
+from .util import have_ipv4, have_ipv6, here
 
 try:
     from .nanoquic import Server
@@ -24,25 +24,37 @@ except ImportError:
         pass
 
 
+addresses = []
+if have_ipv4():
+    addresses.append("127.0.0.1")
+if have_ipv6():
+    addresses.append("::1")
+if len(addresses) == 0:
+    # no networking
+    _nanoquic_available = False
+
+
 @pytest.mark.skipif(not _nanoquic_available, reason="requires aioquic")
 def test_basic_sync():
-    with Server() as server:
-        q = dns.message.make_query("www.example.", "A")
-        r = dns.query.quic(q, "127.0.0.1", port=8853, verify=here("tls/ca.crt"))
-        assert r.rcode() == dns.rcode.REFUSED
+    q = dns.message.make_query("www.example.", "A")
+    for address in addresses:
+        with Server(address) as server:
+            r = dns.query.quic(q, address, port=8853, verify=here("tls/ca.crt"))
+            assert r.rcode() == dns.rcode.REFUSED
 
 
-async def amain():
+async def amain(address):
     q = dns.message.make_query("www.example.", "A")
-    r = await dns.asyncquery.quic(q, "127.0.0.1", port=8853, verify=here("tls/ca.crt"))
+    r = await dns.asyncquery.quic(q, address, port=8853, verify=here("tls/ca.crt"))
     assert r.rcode() == dns.rcode.REFUSED
 
 
 @pytest.mark.skipif(not _nanoquic_available, reason="requires aioquic")
 def test_basic_asyncio():
     dns.asyncbackend.set_default_backend("asyncio")
-    with Server() as server:
-        asyncio.run(amain())
+    for address in addresses:
+        with Server(address) as server:
+            asyncio.run(amain(address))
 
 
 try:
@@ -51,8 +63,9 @@ try:
     @pytest.mark.skipif(not _nanoquic_available, reason="requires aioquic")
     def test_basic_trio():
         dns.asyncbackend.set_default_backend("trio")
-        with Server() as server:
-            trio.run(amain)
+        for address in addresses:
+            with Server(address) as server:
+                trio.run(amain, address)
 
 except ImportError:
     pass