From: Tom Christie Date: Sun, 8 Dec 2019 19:43:33 +0000 (+0000) Subject: Drop TimeoutFlag (#618) X-Git-Tag: 0.9.4~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=83fc0921c175f47607f864678b0b0bba4dbb0413;p=thirdparty%2Fhttpx.git Drop TimeoutFlag (#618) --- diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 89384bbb..51006959 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -6,13 +6,7 @@ import typing from ..config import PoolLimits, Timeout from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout -from .base import ( - BaseEvent, - BasePoolSemaphore, - BaseSocketStream, - ConcurrencyBackend, - TimeoutFlag, -) +from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend SSL_MONKEY_PATCH_APPLIED = False @@ -126,51 +120,26 @@ class SocketStream(BaseSocketStream): ident = ssl_object.selected_alpn_protocol() return "HTTP/2" if ident == "h2" else "HTTP/1.1" - async def read(self, n: int, timeout: Timeout, flag: TimeoutFlag = None) -> bytes: - while True: - # Check our flag at the first possible moment, and use a fine - # grained retry loop if we're not yet in read-timeout mode. - should_raise = flag is None or flag.raise_on_read_timeout - read_timeout = timeout.read_timeout if should_raise else 0.01 - try: - async with self.read_lock: - data = await asyncio.wait_for( - self.stream_reader.read(n), read_timeout - ) - except asyncio.TimeoutError: - if should_raise: - raise ReadTimeout() from None - # FIX(py3.6): yield control back to the event loop to give it a chance - # to cancel `.read(n)` before we retry. - # This prevents concurrent `.read()` calls, which asyncio - # doesn't seem to allow on 3.6. - # See: https://github.com/encode/httpx/issues/382 - await asyncio.sleep(0) - else: - break - - return data + async def read(self, n: int, timeout: Timeout) -> bytes: + try: + async with self.read_lock: + return await asyncio.wait_for( + self.stream_reader.read(n), timeout.read_timeout + ) + except asyncio.TimeoutError: + raise ReadTimeout() from None - async def write( - self, data: bytes, timeout: Timeout, flag: TimeoutFlag = None - ) -> None: + async def write(self, data: bytes, timeout: Timeout) -> None: if not data: return self.stream_writer.write(data) - while True: - try: - await asyncio.wait_for( # type: ignore - self.stream_writer.drain(), timeout.write_timeout - ) - break - except asyncio.TimeoutError: - # We check our flag at the first possible moment, in order to - # allow us to suppress write timeouts, if we've since - # switched over to read-timeout mode. - should_raise = flag is None or flag.raise_on_write_timeout - if should_raise: - raise WriteTimeout() from None + try: + return await asyncio.wait_for( + self.stream_writer.drain(), timeout.write_timeout + ) + except asyncio.TimeoutError: + raise WriteTimeout() from None def is_connection_dropped(self) -> bool: # Counter-intuitively, what we really want to know here is whether the socket is diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index 24b47232..27ce9c89 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -26,38 +26,6 @@ def lookup_backend( raise RuntimeError(f"Unknown or unsupported concurrency backend {backend!r}") -class TimeoutFlag: - """ - A timeout flag holds a state of either read-timeout or write-timeout mode. - - We use this so that we can attempt both reads and writes concurrently, while - only enforcing timeouts in one direction. - - During a request/response cycle we start in write-timeout mode. - - Once we've sent a request fully, or once we start seeing a response, - then we switch to read-timeout mode instead. - """ - - def __init__(self) -> None: - self.raise_on_read_timeout = False - self.raise_on_write_timeout = True - - def set_read_timeouts(self) -> None: - """ - Set the flag to read-timeout mode. - """ - self.raise_on_read_timeout = True - self.raise_on_write_timeout = False - - def set_write_timeouts(self) -> None: - """ - Set the flag to write-timeout mode. - """ - self.raise_on_read_timeout = False - self.raise_on_write_timeout = True - - class BaseSocketStream: """ A socket stream with read/write operations. Abstracts away any asyncio-specific @@ -73,7 +41,7 @@ class BaseSocketStream: ) -> "BaseSocketStream": raise NotImplementedError() # pragma: no cover - async def read(self, n: int, timeout: Timeout, flag: typing.Any = None) -> bytes: + async def read(self, n: int, timeout: Timeout) -> bytes: raise NotImplementedError() # pragma: no cover async def write(self, data: bytes, timeout: Timeout) -> None: diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index 11737f72..7604e965 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -6,16 +6,10 @@ import trio from ..config import PoolLimits, Timeout from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout -from .base import ( - BaseEvent, - BasePoolSemaphore, - BaseSocketStream, - ConcurrencyBackend, - TimeoutFlag, -) +from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend -def _or_inf(value: typing.Optional[float]) -> float: +def none_as_inf(value: typing.Optional[float]) -> float: return value if value is not None else float("inf") @@ -30,18 +24,16 @@ class SocketStream(BaseSocketStream): async def start_tls( self, hostname: str, ssl_context: ssl.SSLContext, timeout: Timeout ) -> "SocketStream": - connect_timeout = _or_inf(timeout.connect_timeout) + connect_timeout = none_as_inf(timeout.connect_timeout) ssl_stream = trio.SSLStream( self.stream, ssl_context=ssl_context, server_hostname=hostname ) - with trio.move_on_after(connect_timeout) as cancel_scope: + with trio.move_on_after(connect_timeout): await ssl_stream.do_handshake() + return SocketStream(ssl_stream) - if cancel_scope.cancelled_caught: - raise ConnectTimeout() - - return SocketStream(ssl_stream) + raise ConnectTimeout() def get_http_version(self) -> str: if not isinstance(self.stream, trio.SSLStream): @@ -50,19 +42,26 @@ class SocketStream(BaseSocketStream): ident = self.stream.selected_alpn_protocol() return "HTTP/2" if ident == "h2" else "HTTP/1.1" - async def read(self, n: int, timeout: Timeout, flag: TimeoutFlag = None) -> bytes: - while True: - # Check our flag at the first possible moment, and use a fine - # grained retry loop if we're not yet in read-timeout mode. - should_raise = flag is None or flag.raise_on_read_timeout - read_timeout = _or_inf(timeout.read_timeout if should_raise else 0.01) + async def read(self, n: int, timeout: Timeout) -> bytes: + read_timeout = none_as_inf(timeout.read_timeout) + + with trio.move_on_after(read_timeout): + async with self.read_lock: + return await self.stream.receive_some(max_bytes=n) + + raise ReadTimeout() + + async def write(self, data: bytes, timeout: Timeout) -> None: + if not data: + return - with trio.move_on_after(read_timeout): - async with self.read_lock: - return await self.stream.receive_some(max_bytes=n) + write_timeout = none_as_inf(timeout.write_timeout) - if should_raise: - raise ReadTimeout() from None + with trio.move_on_after(write_timeout): + async with self.write_lock: + return await self.stream.send_all(data) + + raise WriteTimeout() def is_connection_dropped(self) -> bool: # Adapted from: https://github.com/encode/httpx/pull/143#issuecomment-515202982 @@ -79,26 +78,6 @@ class SocketStream(BaseSocketStream): # See: https://github.com/encode/httpx/pull/143#issuecomment-515181778 return stream.socket.is_readable() - async def write( - self, data: bytes, timeout: Timeout, flag: TimeoutFlag = None - ) -> None: - if not data: - return - - write_timeout = _or_inf(timeout.write_timeout) - - while True: - with trio.move_on_after(write_timeout): - async with self.write_lock: - await self.stream.send_all(data) - break - # We check our flag at the first possible moment, in order to - # allow us to suppress write timeouts, if we've since - # switched over to read-timeout mode. - should_raise = flag is None or flag.raise_on_write_timeout - if should_raise: - raise WriteTimeout() from None - async def close(self) -> None: await self.stream.aclose() @@ -123,7 +102,7 @@ class PoolSemaphore(BasePoolSemaphore): if self.semaphore is None: return - timeout = _or_inf(timeout) + timeout = none_as_inf(timeout) with trio.move_on_after(timeout): await self.semaphore.acquire() @@ -146,18 +125,16 @@ class TrioBackend(ConcurrencyBackend): ssl_context: typing.Optional[ssl.SSLContext], timeout: Timeout, ) -> SocketStream: - connect_timeout = _or_inf(timeout.connect_timeout) + connect_timeout = none_as_inf(timeout.connect_timeout) - with trio.move_on_after(connect_timeout) as cancel_scope: + with trio.move_on_after(connect_timeout): stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port) if ssl_context is not None: stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname) await stream.do_handshake() + return SocketStream(stream=stream) - if cancel_scope.cancelled_caught: - raise ConnectTimeout() - - return SocketStream(stream=stream) + raise ConnectTimeout() async def open_uds_stream( self, @@ -166,18 +143,16 @@ class TrioBackend(ConcurrencyBackend): ssl_context: typing.Optional[ssl.SSLContext], timeout: Timeout, ) -> SocketStream: - connect_timeout = _or_inf(timeout.connect_timeout) + connect_timeout = none_as_inf(timeout.connect_timeout) - with trio.move_on_after(connect_timeout) as cancel_scope: + with trio.move_on_after(connect_timeout): stream: trio.SocketStream = await trio.open_unix_socket(path) if ssl_context is not None: stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname) await stream.do_handshake() + return SocketStream(stream=stream) - if cancel_scope.cancelled_caught: - raise ConnectTimeout() - - return SocketStream(stream=stream) + raise ConnectTimeout() async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any diff --git a/httpx/dispatch/http11.py b/httpx/dispatch/http11.py index 44961f02..903e7d14 100644 --- a/httpx/dispatch/http11.py +++ b/httpx/dispatch/http11.py @@ -2,7 +2,7 @@ import typing import h11 -from ..concurrency.base import BaseSocketStream, TimeoutFlag +from ..concurrency.base import BaseSocketStream from ..config import Timeout from ..exceptions import ConnectionClosed, ProtocolError from ..models import Request, Response @@ -38,7 +38,6 @@ class HTTP11Connection: self.socket = socket self.on_release = on_release self.h11_state = h11.Connection(our_role=h11.CLIENT) - self.timeout_flag = TimeoutFlag() async def send(self, request: Request, timeout: Timeout = None) -> Response: timeout = Timeout() if timeout is None else timeout @@ -102,9 +101,6 @@ class HTTP11Connection: # care about connection errors that occur when sending the body. # Ignore these, and defer to any exceptions on reading the response. self.h11_state.send_failed() - finally: - # Once we've sent the request, we enable read timeouts. - self.timeout_flag.set_read_timeouts() async def _send_event(self, event: H11Event, timeout: Timeout) -> None: """ @@ -122,9 +118,6 @@ class HTTP11Connection: """ while True: event = await self._receive_event(timeout) - # As soon as we start seeing response events, we should enable - # read timeouts, if we haven't already. - self.timeout_flag.set_read_timeouts() if isinstance(event, h11.InformationalResponse): continue else: @@ -171,9 +164,7 @@ class HTTP11Connection: if event is h11.NEED_DATA: try: - data = await self.socket.read( - self.READ_NUM_BYTES, timeout, flag=self.timeout_flag - ) + data = await self.socket.read(self.READ_NUM_BYTES, timeout) except OSError: # pragma: nocover data = b"" self.h11_state.receive_data(data) @@ -194,7 +185,6 @@ class HTTP11Connection: ): # Get ready for another request/response cycle. self.h11_state.start_next_cycle() - self.timeout_flag.set_write_timeouts() else: await self.close()