MonkeyPatch.write = _fixed_write
+class TimeoutFlag:
+ """
+ A timeout flag holds a state of either read-timeout or write-timeout mode.
+
+ We use this so that we can attempt both reads and writes concurrently, while
+ only enforcing timeouts in one direction.
+
+ During a request/response cycle we start in write-timeout mode.
+
+ Once we've sent a request fully, or once we start seeing a response,
+ then we switch to read-timeout mode instead.
+ """
+
+ def __init__(self) -> None:
+ self.raise_on_read_timeout = False
+ self.raise_on_write_timeout = True
+
+ def set_read_timeouts(self) -> None:
+ """
+ Set the flag to read-timeout mode.
+ """
+ self.raise_on_read_timeout = True
+ self.raise_on_write_timeout = False
+
+ def set_write_timeouts(self) -> None:
+ """
+ Set the flag to write-timeout mode.
+ """
+ self.raise_on_read_timeout = False
+ self.raise_on_write_timeout = True
+
+
class Reader(BaseReader):
def __init__(
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
self.stream_reader = stream_reader
self.timeout = timeout
- async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes:
+ async def read(
+ self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
+ ) -> bytes:
if timeout is None:
timeout = self.timeout
- try:
- data = await asyncio.wait_for(
- self.stream_reader.read(n), timeout.read_timeout
- )
- except asyncio.TimeoutError:
- raise ReadTimeout()
+ while True:
+ should_raise = flag is None or flag.raise_on_read_timeout
+ try:
+ data = await asyncio.wait_for(
+ self.stream_reader.read(n), timeout.read_timeout
+ )
+ break
+ except asyncio.TimeoutError:
+ if should_raise:
+ raise ReadTimeout()
return data
def write_no_block(self, data: bytes) -> None:
self.stream_writer.write(data) # pragma: nocover
- async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None:
+ async def write(
+ self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
+ ) -> None:
if not data:
return
timeout = self.timeout
self.stream_writer.write(data)
- try:
- await asyncio.wait_for( # type: ignore
- self.stream_writer.drain(), timeout.write_timeout
- )
- except asyncio.TimeoutError:
- raise WriteTimeout()
+ while True:
+ try:
+ await asyncio.wait_for( # type: ignore
+ self.stream_writer.drain(), timeout.write_timeout
+ )
+ break
+ except asyncio.TimeoutError:
+ should_raise = flag is None or flag.raise_on_write_timeout
+ if should_raise:
+ raise WriteTimeout()
async def close(self) -> None:
self.stream_writer.close()
import h11
+from ..concurrency import TimeoutFlag
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
from ..exceptions import ConnectTimeout, ReadTimeout
from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
self.backend = backend
self.on_release = on_release
self.h11_state = h11.Connection(our_role=h11.CLIENT)
+ self.timeout_flag = TimeoutFlag()
async def send(
self, request: AsyncRequest, timeout: TimeoutTypes = None
# care about connection errors that occur when sending the body.
# Ignore these, and defer to any exceptions on reading the response.
self.h11_state.send_failed()
+ finally:
+ # Once we've sent the request, we enable read timeouts.
+ self.timeout_flag.set_read_timeouts()
async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> None:
"""
"""
while True:
event = await self._receive_event(timeout)
+ # As soon as we start seeing response events, we should enable
+ # read timeouts, if we haven't already.
+ self.timeout_flag.set_read_timeouts()
if isinstance(event, h11.InformationalResponse):
continue
else:
event = self.h11_state.next_event()
if event is h11.NEED_DATA:
try:
- data = await self.reader.read(self.READ_NUM_BYTES, timeout)
+ data = await self.reader.read(
+ self.READ_NUM_BYTES, timeout, flag=self.timeout_flag
+ )
except OSError: # pragma: nocover
data = b""
self.h11_state.receive_data(data)
self.h11_state.our_state is h11.DONE
and self.h11_state.their_state is h11.DONE
):
+ # Get ready for another request/response cycle.
self.h11_state.start_next_cycle()
+ self.timeout_flag.set_write_timeouts()
else:
await self.close()
import h2.connection
import h2.events
+from ..concurrency import TimeoutFlag
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
from ..exceptions import ConnectTimeout, ReadTimeout
from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
self.on_release = on_release
self.h2_state = h2.connection.H2Connection()
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
+ self.timeout_flags = {} # type: typing.Dict[int, TimeoutFlag]
self.initialized = False
async def send(
stream_id = await self.send_headers(request, timeout)
self.events[stream_id] = []
+ self.timeout_flags[stream_id] = TimeoutFlag()
task, args = self.send_request_data, [stream_id, request.stream(), timeout]
async with self.backend.background_manager(task, args=args):
stream: typing.AsyncIterator[bytes],
timeout: TimeoutConfig = None,
) -> None:
- async for data in stream:
- await self.send_data(stream_id, data, timeout)
- await self.end_stream(stream_id, timeout)
+ try:
+ async for data in stream:
+ await self.send_data(stream_id, data, timeout)
+ await self.end_stream(stream_id, timeout)
+ finally:
+ # Once we've sent the request we should enable read timeouts.
+ self.timeout_flags[stream_id].set_read_timeouts()
async def send_data(
self, stream_id: int, data: bytes, timeout: TimeoutConfig = None
"""
while True:
event = await self.receive_event(stream_id, timeout)
+ # As soon as we start seeing response events, we should enable
+ # read timeouts, if we haven't already.
+ self.timeout_flags[stream_id].set_read_timeouts()
if isinstance(event, h2.events.ResponseReceived):
break
self, stream_id: int, timeout: TimeoutConfig = None
) -> h2.events.Event:
while not self.events[stream_id]:
- data = await self.reader.read(self.READ_NUM_BYTES, timeout)
+ flag = self.timeout_flags[stream_id]
+ data = await self.reader.read(self.READ_NUM_BYTES, timeout, flag=flag)
events = self.h2_state.receive_data(data)
for event in events:
if getattr(event, "stream_id", 0):
async def response_closed(self, stream_id: int) -> None:
del self.events[stream_id]
+ del self.timeout_flags[stream_id]
if not self.events and self.on_release is not None:
await self.on_release()
backend, or for stand-alone test cases.
"""
- async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes:
+ async def read(
+ self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
+ ) -> bytes:
raise NotImplementedError() # pragma: no cover
# BaseReader interface
- async def read(self, n, timeout) -> bytes:
+ async def read(self, n, timeout, flag=None) -> bytes:
await asyncio.sleep(0)
send, self.buffer = self.buffer[:n], self.buffer[n:]
return send