From 186922d1f083d52401d0cbd13a0d1b929b3db3cb Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Thu, 7 Dec 2023 09:59:08 -0800 Subject: [PATCH] Fix setting source for sync/trio quic queries. The sync code called connect() before bind(), which meant that any attempt to specify a source resulted in an exception. This switches the order. The trio code called a nonexistent method in the wrong place, so didn't work at all. This fixes the call and puts it in the right place. The asyncio code worked, so no changes were needed. --- dns/quic/_sync.py | 8 ++++---- dns/quic/_trio.py | 6 ++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py index d6731c90..120cb5f3 100644 --- a/dns/quic/_sync.py +++ b/dns/quic/_sync.py @@ -82,10 +82,6 @@ class SyncQuicConnection(BaseQuicConnection): def __init__(self, connection, address, port, source, source_port, manager): super().__init__(connection, address, port, source, source_port, manager) self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0) - self._socket.connect(self._peer) - (self._send_wakeup, self._receive_wakeup) = socket.socketpair() - self._receive_wakeup.setblocking(False) - self._socket.setblocking(False) if self._source is not None: try: self._socket.bind( @@ -94,6 +90,10 @@ class SyncQuicConnection(BaseQuicConnection): except Exception: self._socket.close() raise + self._socket.connect(self._peer) + (self._send_wakeup, self._receive_wakeup) = socket.socketpair() + self._receive_wakeup.setblocking(False) + self._socket.setblocking(False) self._handshake_complete = threading.Event() self._worker_thread = None self._lock = threading.Lock() diff --git a/dns/quic/_trio.py b/dns/quic/_trio.py index 0284c982..35e36b98 100644 --- a/dns/quic/_trio.py +++ b/dns/quic/_trio.py @@ -76,8 +76,6 @@ class TrioQuicConnection(AsyncQuicConnection): def __init__(self, connection, address, port, source, source_port, manager=None): super().__init__(connection, address, port, source, source_port, manager) self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0) - if self._source: - trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af)) self._handshake_complete = trio.Event() self._run_done = trio.Event() self._worker_scope = None @@ -85,6 +83,10 @@ class TrioQuicConnection(AsyncQuicConnection): async def _worker(self): try: + if self._source: + await self._socket.bind( + dns.inet.low_level_address_tuple(self._source, self._af) + ) await self._socket.connect(self._peer) while not self._done: (expiration, interval) = self._get_timer_values(False) -- 2.47.3