From: Brian Wellington Date: Thu, 7 Dec 2023 17:59:08 +0000 (-0800) Subject: Fix setting source for sync/trio quic queries. X-Git-Tag: v2.5.0rc1~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=186922d1f083d52401d0cbd13a0d1b929b3db3cb;p=thirdparty%2Fdnspython.git Fix setting source for sync/trio quic queries. The sync code called connect() before bind(), which meant that any attempt to specify a source resulted in an exception. This switches the order. The trio code called a nonexistent method in the wrong place, so didn't work at all. This fixes the call and puts it in the right place. The asyncio code worked, so no changes were needed. --- diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py index d6731c90..120cb5f3 100644 --- a/dns/quic/_sync.py +++ b/dns/quic/_sync.py @@ -82,10 +82,6 @@ class SyncQuicConnection(BaseQuicConnection): def __init__(self, connection, address, port, source, source_port, manager): super().__init__(connection, address, port, source, source_port, manager) self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0) - self._socket.connect(self._peer) - (self._send_wakeup, self._receive_wakeup) = socket.socketpair() - self._receive_wakeup.setblocking(False) - self._socket.setblocking(False) if self._source is not None: try: self._socket.bind( @@ -94,6 +90,10 @@ class SyncQuicConnection(BaseQuicConnection): except Exception: self._socket.close() raise + self._socket.connect(self._peer) + (self._send_wakeup, self._receive_wakeup) = socket.socketpair() + self._receive_wakeup.setblocking(False) + self._socket.setblocking(False) self._handshake_complete = threading.Event() self._worker_thread = None self._lock = threading.Lock() diff --git a/dns/quic/_trio.py b/dns/quic/_trio.py index 0284c982..35e36b98 100644 --- a/dns/quic/_trio.py +++ b/dns/quic/_trio.py @@ -76,8 +76,6 @@ class TrioQuicConnection(AsyncQuicConnection): def __init__(self, connection, address, port, source, source_port, manager=None): super().__init__(connection, address, port, source, source_port, manager) self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0) - if self._source: - trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af)) self._handshake_complete = trio.Event() self._run_done = trio.Event() self._worker_scope = None @@ -85,6 +83,10 @@ class TrioQuicConnection(AsyncQuicConnection): async def _worker(self): try: + if self._source: + await self._socket.bind( + dns.inet.low_level_address_tuple(self._source, self._af) + ) await self._socket.connect(self._peer) while not self._done: (expiration, interval) = self._get_timer_values(False)