]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Fix setting source for sync/trio quic queries.
authorBrian Wellington <bwelling@xbill.org>
Thu, 7 Dec 2023 17:59:08 +0000 (09:59 -0800)
committerBrian Wellington <bwelling@xbill.org>
Thu, 7 Dec 2023 18:01:06 +0000 (10:01 -0800)
The sync code called connect() before bind(), which meant that any
attempt to specify a source resulted in an exception.  This switches the
order.

The trio code called a nonexistent method in the wrong place, so didn't
work at all.  This fixes the call and puts it in the right place.

The asyncio code worked, so no changes were needed.

dns/quic/_sync.py
dns/quic/_trio.py

index d6731c904c34605dff3a78afec9d2e22f0bbb81f..120cb5f329c779a6556143339e536d1443bf3ddf 100644 (file)
@@ -82,10 +82,6 @@ class SyncQuicConnection(BaseQuicConnection):
     def __init__(self, connection, address, port, source, source_port, manager):
         super().__init__(connection, address, port, source, source_port, manager)
         self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
-        self._socket.connect(self._peer)
-        (self._send_wakeup, self._receive_wakeup) = socket.socketpair()
-        self._receive_wakeup.setblocking(False)
-        self._socket.setblocking(False)
         if self._source is not None:
             try:
                 self._socket.bind(
@@ -94,6 +90,10 @@ class SyncQuicConnection(BaseQuicConnection):
             except Exception:
                 self._socket.close()
                 raise
+        self._socket.connect(self._peer)
+        (self._send_wakeup, self._receive_wakeup) = socket.socketpair()
+        self._receive_wakeup.setblocking(False)
+        self._socket.setblocking(False)
         self._handshake_complete = threading.Event()
         self._worker_thread = None
         self._lock = threading.Lock()
index 0284c98294feb1ddcf406e0a8ccc39b791d6e097..35e36b982f71df873fd5ac70edd46179a8091ab4 100644 (file)
@@ -76,8 +76,6 @@ class TrioQuicConnection(AsyncQuicConnection):
     def __init__(self, connection, address, port, source, source_port, manager=None):
         super().__init__(connection, address, port, source, source_port, manager)
         self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
-        if self._source:
-            trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af))
         self._handshake_complete = trio.Event()
         self._run_done = trio.Event()
         self._worker_scope = None
@@ -85,6 +83,10 @@ class TrioQuicConnection(AsyncQuicConnection):
 
     async def _worker(self):
         try:
+            if self._source:
+                await self._socket.bind(
+                    dns.inet.low_level_address_tuple(self._source, self._af)
+                )
             await self._socket.connect(self._peer)
             while not self._done:
                 (expiration, interval) = self._get_timer_values(False)