From: Tom Christie Date: Tue, 25 Jun 2019 11:54:14 +0000 (+0100) Subject: Read/Write timeout modes (#104) X-Git-Tag: 0.6.4~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f4f05e343ceef86042cb4f6991ee6b924e434493;p=thirdparty%2Fhttpx.git Read/Write timeout modes (#104) * Read/Write timeout modes * Read/Write timeout modes --- diff --git a/http3/concurrency.py b/http3/concurrency.py index fd6af368..bd04c2da 100644 --- a/http3/concurrency.py +++ b/http3/concurrency.py @@ -49,6 +49,38 @@ def ssl_monkey_patch() -> None: MonkeyPatch.write = _fixed_write +class TimeoutFlag: + """ + A timeout flag holds a state of either read-timeout or write-timeout mode. + + We use this so that we can attempt both reads and writes concurrently, while + only enforcing timeouts in one direction. + + During a request/response cycle we start in write-timeout mode. + + Once we've sent a request fully, or once we start seeing a response, + then we switch to read-timeout mode instead. + """ + + def __init__(self) -> None: + self.raise_on_read_timeout = False + self.raise_on_write_timeout = True + + def set_read_timeouts(self) -> None: + """ + Set the flag to read-timeout mode. + """ + self.raise_on_read_timeout = True + self.raise_on_write_timeout = False + + def set_write_timeouts(self) -> None: + """ + Set the flag to write-timeout mode. + """ + self.raise_on_read_timeout = False + self.raise_on_write_timeout = True + + class Reader(BaseReader): def __init__( self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig @@ -56,16 +88,22 @@ class Reader(BaseReader): self.stream_reader = stream_reader self.timeout = timeout - async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes: + async def read( + self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None + ) -> bytes: if timeout is None: timeout = self.timeout - try: - data = await asyncio.wait_for( - self.stream_reader.read(n), timeout.read_timeout - ) - except asyncio.TimeoutError: - raise ReadTimeout() + while True: + should_raise = flag is None or flag.raise_on_read_timeout + try: + data = await asyncio.wait_for( + self.stream_reader.read(n), timeout.read_timeout + ) + break + except asyncio.TimeoutError: + if should_raise: + raise ReadTimeout() return data @@ -78,7 +116,9 @@ class Writer(BaseWriter): def write_no_block(self, data: bytes) -> None: self.stream_writer.write(data) # pragma: nocover - async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None: + async def write( + self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None + ) -> None: if not data: return @@ -86,12 +126,16 @@ class Writer(BaseWriter): timeout = self.timeout self.stream_writer.write(data) - try: - await asyncio.wait_for( # type: ignore - self.stream_writer.drain(), timeout.write_timeout - ) - except asyncio.TimeoutError: - raise WriteTimeout() + while True: + try: + await asyncio.wait_for( # type: ignore + self.stream_writer.drain(), timeout.write_timeout + ) + break + except asyncio.TimeoutError: + should_raise = flag is None or flag.raise_on_write_timeout + if should_raise: + raise WriteTimeout() async def close(self) -> None: self.stream_writer.close() diff --git a/http3/dispatch/http11.py b/http3/dispatch/http11.py index 1f632d8e..ae9e7460 100644 --- a/http3/dispatch/http11.py +++ b/http3/dispatch/http11.py @@ -2,6 +2,7 @@ import typing import h11 +from ..concurrency import TimeoutFlag from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes from ..exceptions import ConnectTimeout, ReadTimeout from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend @@ -38,6 +39,7 @@ class HTTP11Connection: self.backend = backend self.on_release = on_release self.h11_state = h11.Connection(our_role=h11.CLIENT) + self.timeout_flag = TimeoutFlag() async def send( self, request: AsyncRequest, timeout: TimeoutTypes = None @@ -103,6 +105,9 @@ class HTTP11Connection: # care about connection errors that occur when sending the body. # Ignore these, and defer to any exceptions on reading the response. self.h11_state.send_failed() + finally: + # Once we've sent the request, we enable read timeouts. + self.timeout_flag.set_read_timeouts() async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> None: """ @@ -120,6 +125,9 @@ class HTTP11Connection: """ while True: event = await self._receive_event(timeout) + # As soon as we start seeing response events, we should enable + # read timeouts, if we haven't already. + self.timeout_flag.set_read_timeouts() if isinstance(event, h11.InformationalResponse): continue else: @@ -149,7 +157,9 @@ class HTTP11Connection: event = self.h11_state.next_event() if event is h11.NEED_DATA: try: - data = await self.reader.read(self.READ_NUM_BYTES, timeout) + data = await self.reader.read( + self.READ_NUM_BYTES, timeout, flag=self.timeout_flag + ) except OSError: # pragma: nocover data = b"" self.h11_state.receive_data(data) @@ -162,7 +172,9 @@ class HTTP11Connection: self.h11_state.our_state is h11.DONE and self.h11_state.their_state is h11.DONE ): + # Get ready for another request/response cycle. self.h11_state.start_next_cycle() + self.timeout_flag.set_write_timeouts() else: await self.close() diff --git a/http3/dispatch/http2.py b/http3/dispatch/http2.py index ae42b273..9bd35eaf 100644 --- a/http3/dispatch/http2.py +++ b/http3/dispatch/http2.py @@ -4,6 +4,7 @@ import typing import h2.connection import h2.events +from ..concurrency import TimeoutFlag from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes from ..exceptions import ConnectTimeout, ReadTimeout from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend @@ -26,6 +27,7 @@ class HTTP2Connection: self.on_release = on_release self.h2_state = h2.connection.H2Connection() self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]] + self.timeout_flags = {} # type: typing.Dict[int, TimeoutFlag] self.initialized = False async def send( @@ -39,6 +41,7 @@ class HTTP2Connection: stream_id = await self.send_headers(request, timeout) self.events[stream_id] = [] + self.timeout_flags[stream_id] = TimeoutFlag() task, args = self.send_request_data, [stream_id, request.stream(), timeout] async with self.backend.background_manager(task, args=args): @@ -85,9 +88,13 @@ class HTTP2Connection: stream: typing.AsyncIterator[bytes], timeout: TimeoutConfig = None, ) -> None: - async for data in stream: - await self.send_data(stream_id, data, timeout) - await self.end_stream(stream_id, timeout) + try: + async for data in stream: + await self.send_data(stream_id, data, timeout) + await self.end_stream(stream_id, timeout) + finally: + # Once we've sent the request we should enable read timeouts. + self.timeout_flags[stream_id].set_read_timeouts() async def send_data( self, stream_id: int, data: bytes, timeout: TimeoutConfig = None @@ -113,6 +120,9 @@ class HTTP2Connection: """ while True: event = await self.receive_event(stream_id, timeout) + # As soon as we start seeing response events, we should enable + # read timeouts, if we haven't already. + self.timeout_flags[stream_id].set_read_timeouts() if isinstance(event, h2.events.ResponseReceived): break @@ -140,7 +150,8 @@ class HTTP2Connection: self, stream_id: int, timeout: TimeoutConfig = None ) -> h2.events.Event: while not self.events[stream_id]: - data = await self.reader.read(self.READ_NUM_BYTES, timeout) + flag = self.timeout_flags[stream_id] + data = await self.reader.read(self.READ_NUM_BYTES, timeout, flag=flag) events = self.h2_state.receive_data(data) for event in events: if getattr(event, "stream_id", 0): @@ -153,6 +164,7 @@ class HTTP2Connection: async def response_closed(self, stream_id: int) -> None: del self.events[stream_id] + del self.timeout_flags[stream_id] if not self.events and self.on_release is not None: await self.on_release() diff --git a/http3/interfaces.py b/http3/interfaces.py index 23126397..5d9b99c7 100644 --- a/http3/interfaces.py +++ b/http3/interfaces.py @@ -127,7 +127,9 @@ class BaseReader: backend, or for stand-alone test cases. """ - async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes: + async def read( + self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None + ) -> bytes: raise NotImplementedError() # pragma: no cover diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index cdb7c031..4764f318 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -45,7 +45,7 @@ class MockHTTP2Server(BaseReader, BaseWriter): # BaseReader interface - async def read(self, n, timeout) -> bytes: + async def read(self, n, timeout, flag=None) -> bytes: await asyncio.sleep(0) send, self.buffer = self.buffer[:n], self.buffer[n:] return send