From: Tom Christie Date: Sat, 7 Dec 2019 11:09:58 +0000 (+0000) Subject: Backend operations like .read(), .write() now have a manadatory timeout argument... X-Git-Tag: 0.9.4~15 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bc54dd0399065e24e3886fda5b5246b925d8242a;p=thirdparty%2Fhttpx.git Backend operations like .read(), .write() now have a manadatory timeout argument. (#611) --- diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index eee4ec32..89384bbb 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -40,14 +40,10 @@ def ssl_monkey_patch() -> None: class SocketStream(BaseSocketStream): def __init__( - self, - stream_reader: asyncio.StreamReader, - stream_writer: asyncio.StreamWriter, - timeout: Timeout, + self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter, ): self.stream_reader = stream_reader self.stream_writer = stream_writer - self.timeout = timeout self.read_lock = asyncio.Lock() self._inner: typing.Optional[SocketStream] = None @@ -114,7 +110,7 @@ class SocketStream(BaseSocketStream): transport=transport, protocol=protocol, reader=stream_reader, loop=loop ) - ssl_stream = SocketStream(stream_reader, stream_writer, self.timeout) + ssl_stream = SocketStream(stream_reader, stream_writer) # When we return a new SocketStream with new StreamReader/StreamWriter instances # we need to keep references to the old StreamReader/StreamWriter so that they # are not garbage collected and closed while we're still using them. @@ -130,12 +126,7 @@ class SocketStream(BaseSocketStream): ident = ssl_object.selected_alpn_protocol() return "HTTP/2" if ident == "h2" else "HTTP/1.1" - async def read( - self, n: int, timeout: Timeout = None, flag: TimeoutFlag = None - ) -> bytes: - if timeout is None: - timeout = self.timeout - + async def read(self, n: int, timeout: Timeout, flag: TimeoutFlag = None) -> bytes: while True: # Check our flag at the first possible moment, and use a fine # grained retry loop if we're not yet in read-timeout mode. @@ -161,14 +152,11 @@ class SocketStream(BaseSocketStream): return data async def write( - self, data: bytes, timeout: Timeout = None, flag: TimeoutFlag = None + self, data: bytes, timeout: Timeout, flag: TimeoutFlag = None ) -> None: if not data: return - if timeout is None: - timeout = self.timeout - self.stream_writer.write(data) while True: try: @@ -269,9 +257,7 @@ class AsyncioBackend(ConcurrencyBackend): except asyncio.TimeoutError: raise ConnectTimeout() - return SocketStream( - stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout - ) + return SocketStream(stream_reader=stream_reader, stream_writer=stream_writer) async def open_uds_stream( self, @@ -292,9 +278,7 @@ class AsyncioBackend(ConcurrencyBackend): except asyncio.TimeoutError: raise ConnectTimeout() - return SocketStream( - stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout - ) + return SocketStream(stream_reader=stream_reader, stream_writer=stream_writer) async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index 39ac51c5..11737f72 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -21,10 +21,9 @@ def _or_inf(value: typing.Optional[float]) -> float: class SocketStream(BaseSocketStream): def __init__( - self, stream: typing.Union[trio.SocketStream, trio.SSLStream], timeout: Timeout, + self, stream: typing.Union[trio.SocketStream, trio.SSLStream], ) -> None: self.stream = stream - self.timeout = timeout self.read_lock = trio.Lock() self.write_lock = trio.Lock() @@ -42,7 +41,7 @@ class SocketStream(BaseSocketStream): if cancel_scope.cancelled_caught: raise ConnectTimeout() - return SocketStream(ssl_stream, self.timeout) + return SocketStream(ssl_stream) def get_http_version(self) -> str: if not isinstance(self.stream, trio.SSLStream): @@ -51,12 +50,7 @@ class SocketStream(BaseSocketStream): ident = self.stream.selected_alpn_protocol() return "HTTP/2" if ident == "h2" else "HTTP/1.1" - async def read( - self, n: int, timeout: Timeout = None, flag: TimeoutFlag = None - ) -> bytes: - if timeout is None: - timeout = self.timeout - + async def read(self, n: int, timeout: Timeout, flag: TimeoutFlag = None) -> bytes: while True: # Check our flag at the first possible moment, and use a fine # grained retry loop if we're not yet in read-timeout mode. @@ -86,14 +80,11 @@ class SocketStream(BaseSocketStream): return stream.socket.is_readable() async def write( - self, data: bytes, timeout: Timeout = None, flag: TimeoutFlag = None + self, data: bytes, timeout: Timeout, flag: TimeoutFlag = None ) -> None: if not data: return - if timeout is None: - timeout = self.timeout - write_timeout = _or_inf(timeout.write_timeout) while True: @@ -166,7 +157,7 @@ class TrioBackend(ConcurrencyBackend): if cancel_scope.cancelled_caught: raise ConnectTimeout() - return SocketStream(stream=stream, timeout=timeout) + return SocketStream(stream=stream) async def open_uds_stream( self, @@ -186,7 +177,7 @@ class TrioBackend(ConcurrencyBackend): if cancel_scope.cancelled_caught: raise ConnectTimeout() - return SocketStream(stream=stream, timeout=timeout) + return SocketStream(stream=stream) async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index a3b1df71..a3286c55 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -16,7 +16,7 @@ def get_trio_cipher(stream): return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None -async def read_response(stream, timeout: float, should_contain: bytes) -> bytes: +async def read_response(stream, timeout: Timeout, should_contain: bytes) -> bytes: # stream.read() only gives us *up to* as much data as we ask for. In order to # cleanly close the stream, we must read until the end of the HTTP response. response = b"" @@ -57,7 +57,7 @@ async def test_start_tls_on_tcp_socket_stream(https_server, backend, get_cipher) assert stream.is_connection_dropped() is False assert get_cipher(stream) is not None - await stream.write(b"GET / HTTP/1.1\r\n\r\n") + await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout) response = await read_response(stream, timeout, should_contain=b"Hello, world") assert response.startswith(b"HTTP/1.1 200 OK\r\n") @@ -89,7 +89,7 @@ async def test_start_tls_on_uds_socket_stream(https_uds_server, backend, get_cip assert stream.is_connection_dropped() is False assert get_cipher(stream) is not None - await stream.write(b"GET / HTTP/1.1\r\n\r\n") + await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout) response = await read_response(stream, timeout, should_contain=b"Hello, world") assert response.startswith(b"HTTP/1.1 200 OK\r\n") @@ -105,10 +105,11 @@ async def test_concurrent_read(server, backend): stream = await backend.open_tcp_stream( server.url.host, server.url.port, ssl_context=None, timeout=Timeout(5) ) + timeout = Timeout(5) try: - await stream.write(b"GET / HTTP/1.1\r\n\r\n") + await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout) await run_concurrently( - backend, lambda: stream.read(10), lambda: stream.read(10) + backend, lambda: stream.read(10, timeout), lambda: stream.read(10, timeout) ) finally: await stream.close()