From: Bob Halley Date: Tue, 8 Aug 2023 01:16:42 +0000 (-0700) Subject: Ensure async https() requests are bounded in total time X-Git-Tag: v2.5.0rc1~53 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=a22644d7ced90ee796592622eebd9629377b39ba;p=thirdparty%2Fdnspython.git Ensure async https() requests are bounded in total time according to the timeout [#978]. Unfortunately we do not currently have a good way to make this guarantee for sync https() calls. --- diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py index cebcbdfd..49f14fed 100644 --- a/dns/_asyncbackend.py +++ b/dns/_asyncbackend.py @@ -94,3 +94,6 @@ class Backend: # pragma: no cover def get_transport_class(self): raise NotImplementedError + + async def wait_for(self, awaitable, timeout): + raise NotImplementedError diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index 0021f84f..2631228e 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -270,3 +270,6 @@ class Backend(dns._asyncbackend.Backend): def get_transport_class(self): return _HTTPTransport + + async def wait_for(self, awaitable, timeout): + return await _maybe_wait_for(awaitable, timeout) diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index d414f0b3..4d9fb820 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -237,3 +237,10 @@ class Backend(dns._asyncbackend.Backend): def get_transport_class(self): return _HTTPTransport + + async def wait_for(self, awaitable, timeout): + with _maybe_timeout(timeout): + return await awaitable + raise dns.exception.Timeout( + timeout=timeout + ) # pragma: no cover lgtm[py/unreachable-statement] diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 737e1c92..ecf9c1a5 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -563,14 +563,14 @@ async def https( "content-length": str(len(wire)), } ) - response = await the_client.post( - url, headers=headers, content=wire, timeout=timeout + response = await backend.wait_for( + the_client.post(url, headers=headers, content=wire), timeout ) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") twire = wire.decode() # httpx does a repr() if we give it bytes - response = await the_client.get( - url, headers=headers, timeout=timeout, params={"dns": twire} + response = await backend.wait_for( + the_client.get(url, headers=headers, params={"dns": twire}), timeout ) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH