From 52c2dc1bc360d604081e0b981d69dc6c553a7b0a Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sat, 5 Aug 2023 13:35:29 -0700 Subject: [PATCH] Fix unintended "wait forever" behavior with zero timeouts [#976]. In a few places we did "if timeout:" or "if expiration:" when we really meant "if timeout is not None:". This meant that in the zero timeout case we fell into the "wait forever" path instead of immediately timing out. In the case of UDP queries, we'd be waiting on recvfrom() and if a packet was lost, then the code would never wake up. (cherry picked from commit 0c183f10c78941a4e72046d4dcb2ecf20083b398) --- dns/_asyncio_backend.py | 2 +- dns/_trio_backend.py | 2 +- dns/asyncquery.py | 2 +- dns/resolver.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index 94f751b1..0021f84f 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -51,7 +51,7 @@ class _DatagramProtocol: async def _maybe_wait_for(awaitable, timeout): - if timeout: + if timeout is not None: try: return await asyncio.wait_for(awaitable, timeout) except asyncio.TimeoutError: diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index 14f05280..d414f0b3 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -13,7 +13,7 @@ import dns.inet def _maybe_timeout(timeout): - if timeout: + if timeout is not None: return trio.move_on_after(timeout) else: return dns._asyncbackend.NullContext() diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 97295a29..737e1c92 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -72,7 +72,7 @@ def _source_tuple(af, address, port): def _timeout(expiration, now=None): - if expiration: + if expiration is not None: if not now: now = time.time() return max(expiration - now, 0) diff --git a/dns/resolver.py b/dns/resolver.py index bbac49a5..f08f824d 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -1697,7 +1697,7 @@ def zone_for_name( while 1: try: rlifetime: Optional[float] - if expiration: + if expiration is not None: rlifetime = expiration - time.time() if rlifetime <= 0: rlifetime = 0 -- 2.47.3