from ..config import PoolLimits, Timeout
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-from .base import (
- BaseEvent,
- BasePoolSemaphore,
- BaseSocketStream,
- ConcurrencyBackend,
- TimeoutFlag,
-)
+from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
SSL_MONKEY_PATCH_APPLIED = False
ident = ssl_object.selected_alpn_protocol()
return "HTTP/2" if ident == "h2" else "HTTP/1.1"
- async def read(self, n: int, timeout: Timeout, flag: TimeoutFlag = None) -> bytes:
- while True:
- # Check our flag at the first possible moment, and use a fine
- # grained retry loop if we're not yet in read-timeout mode.
- should_raise = flag is None or flag.raise_on_read_timeout
- read_timeout = timeout.read_timeout if should_raise else 0.01
- try:
- async with self.read_lock:
- data = await asyncio.wait_for(
- self.stream_reader.read(n), read_timeout
- )
- except asyncio.TimeoutError:
- if should_raise:
- raise ReadTimeout() from None
- # FIX(py3.6): yield control back to the event loop to give it a chance
- # to cancel `.read(n)` before we retry.
- # This prevents concurrent `.read()` calls, which asyncio
- # doesn't seem to allow on 3.6.
- # See: https://github.com/encode/httpx/issues/382
- await asyncio.sleep(0)
- else:
- break
-
- return data
+ async def read(self, n: int, timeout: Timeout) -> bytes:
+ try:
+ async with self.read_lock:
+ return await asyncio.wait_for(
+ self.stream_reader.read(n), timeout.read_timeout
+ )
+ except asyncio.TimeoutError:
+ raise ReadTimeout() from None
- async def write(
- self, data: bytes, timeout: Timeout, flag: TimeoutFlag = None
- ) -> None:
+ async def write(self, data: bytes, timeout: Timeout) -> None:
if not data:
return
self.stream_writer.write(data)
- while True:
- try:
- await asyncio.wait_for( # type: ignore
- self.stream_writer.drain(), timeout.write_timeout
- )
- break
- except asyncio.TimeoutError:
- # We check our flag at the first possible moment, in order to
- # allow us to suppress write timeouts, if we've since
- # switched over to read-timeout mode.
- should_raise = flag is None or flag.raise_on_write_timeout
- if should_raise:
- raise WriteTimeout() from None
+ try:
+ return await asyncio.wait_for(
+ self.stream_writer.drain(), timeout.write_timeout
+ )
+ except asyncio.TimeoutError:
+ raise WriteTimeout() from None
def is_connection_dropped(self) -> bool:
# Counter-intuitively, what we really want to know here is whether the socket is
raise RuntimeError(f"Unknown or unsupported concurrency backend {backend!r}")
-class TimeoutFlag:
- """
- A timeout flag holds a state of either read-timeout or write-timeout mode.
-
- We use this so that we can attempt both reads and writes concurrently, while
- only enforcing timeouts in one direction.
-
- During a request/response cycle we start in write-timeout mode.
-
- Once we've sent a request fully, or once we start seeing a response,
- then we switch to read-timeout mode instead.
- """
-
- def __init__(self) -> None:
- self.raise_on_read_timeout = False
- self.raise_on_write_timeout = True
-
- def set_read_timeouts(self) -> None:
- """
- Set the flag to read-timeout mode.
- """
- self.raise_on_read_timeout = True
- self.raise_on_write_timeout = False
-
- def set_write_timeouts(self) -> None:
- """
- Set the flag to write-timeout mode.
- """
- self.raise_on_read_timeout = False
- self.raise_on_write_timeout = True
-
-
class BaseSocketStream:
"""
A socket stream with read/write operations. Abstracts away any asyncio-specific
) -> "BaseSocketStream":
raise NotImplementedError() # pragma: no cover
- async def read(self, n: int, timeout: Timeout, flag: typing.Any = None) -> bytes:
+ async def read(self, n: int, timeout: Timeout) -> bytes:
raise NotImplementedError() # pragma: no cover
async def write(self, data: bytes, timeout: Timeout) -> None:
from ..config import PoolLimits, Timeout
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-from .base import (
- BaseEvent,
- BasePoolSemaphore,
- BaseSocketStream,
- ConcurrencyBackend,
- TimeoutFlag,
-)
+from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
-def _or_inf(value: typing.Optional[float]) -> float:
+def none_as_inf(value: typing.Optional[float]) -> float:
return value if value is not None else float("inf")
async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: Timeout
) -> "SocketStream":
- connect_timeout = _or_inf(timeout.connect_timeout)
+ connect_timeout = none_as_inf(timeout.connect_timeout)
ssl_stream = trio.SSLStream(
self.stream, ssl_context=ssl_context, server_hostname=hostname
)
- with trio.move_on_after(connect_timeout) as cancel_scope:
+ with trio.move_on_after(connect_timeout):
await ssl_stream.do_handshake()
+ return SocketStream(ssl_stream)
- if cancel_scope.cancelled_caught:
- raise ConnectTimeout()
-
- return SocketStream(ssl_stream)
+ raise ConnectTimeout()
def get_http_version(self) -> str:
if not isinstance(self.stream, trio.SSLStream):
ident = self.stream.selected_alpn_protocol()
return "HTTP/2" if ident == "h2" else "HTTP/1.1"
- async def read(self, n: int, timeout: Timeout, flag: TimeoutFlag = None) -> bytes:
- while True:
- # Check our flag at the first possible moment, and use a fine
- # grained retry loop if we're not yet in read-timeout mode.
- should_raise = flag is None or flag.raise_on_read_timeout
- read_timeout = _or_inf(timeout.read_timeout if should_raise else 0.01)
+ async def read(self, n: int, timeout: Timeout) -> bytes:
+ read_timeout = none_as_inf(timeout.read_timeout)
+
+ with trio.move_on_after(read_timeout):
+ async with self.read_lock:
+ return await self.stream.receive_some(max_bytes=n)
+
+ raise ReadTimeout()
+
+ async def write(self, data: bytes, timeout: Timeout) -> None:
+ if not data:
+ return
- with trio.move_on_after(read_timeout):
- async with self.read_lock:
- return await self.stream.receive_some(max_bytes=n)
+ write_timeout = none_as_inf(timeout.write_timeout)
- if should_raise:
- raise ReadTimeout() from None
+ with trio.move_on_after(write_timeout):
+ async with self.write_lock:
+ return await self.stream.send_all(data)
+
+ raise WriteTimeout()
def is_connection_dropped(self) -> bool:
# Adapted from: https://github.com/encode/httpx/pull/143#issuecomment-515202982
# See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
return stream.socket.is_readable()
- async def write(
- self, data: bytes, timeout: Timeout, flag: TimeoutFlag = None
- ) -> None:
- if not data:
- return
-
- write_timeout = _or_inf(timeout.write_timeout)
-
- while True:
- with trio.move_on_after(write_timeout):
- async with self.write_lock:
- await self.stream.send_all(data)
- break
- # We check our flag at the first possible moment, in order to
- # allow us to suppress write timeouts, if we've since
- # switched over to read-timeout mode.
- should_raise = flag is None or flag.raise_on_write_timeout
- if should_raise:
- raise WriteTimeout() from None
-
async def close(self) -> None:
await self.stream.aclose()
if self.semaphore is None:
return
- timeout = _or_inf(timeout)
+ timeout = none_as_inf(timeout)
with trio.move_on_after(timeout):
await self.semaphore.acquire()
ssl_context: typing.Optional[ssl.SSLContext],
timeout: Timeout,
) -> SocketStream:
- connect_timeout = _or_inf(timeout.connect_timeout)
+ connect_timeout = none_as_inf(timeout.connect_timeout)
- with trio.move_on_after(connect_timeout) as cancel_scope:
+ with trio.move_on_after(connect_timeout):
stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
if ssl_context is not None:
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
await stream.do_handshake()
+ return SocketStream(stream=stream)
- if cancel_scope.cancelled_caught:
- raise ConnectTimeout()
-
- return SocketStream(stream=stream)
+ raise ConnectTimeout()
async def open_uds_stream(
self,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: Timeout,
) -> SocketStream:
- connect_timeout = _or_inf(timeout.connect_timeout)
+ connect_timeout = none_as_inf(timeout.connect_timeout)
- with trio.move_on_after(connect_timeout) as cancel_scope:
+ with trio.move_on_after(connect_timeout):
stream: trio.SocketStream = await trio.open_unix_socket(path)
if ssl_context is not None:
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
await stream.do_handshake()
+ return SocketStream(stream=stream)
- if cancel_scope.cancelled_caught:
- raise ConnectTimeout()
-
- return SocketStream(stream=stream)
+ raise ConnectTimeout()
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
import h11
-from ..concurrency.base import BaseSocketStream, TimeoutFlag
+from ..concurrency.base import BaseSocketStream
from ..config import Timeout
from ..exceptions import ConnectionClosed, ProtocolError
from ..models import Request, Response
self.socket = socket
self.on_release = on_release
self.h11_state = h11.Connection(our_role=h11.CLIENT)
- self.timeout_flag = TimeoutFlag()
async def send(self, request: Request, timeout: Timeout = None) -> Response:
timeout = Timeout() if timeout is None else timeout
# care about connection errors that occur when sending the body.
# Ignore these, and defer to any exceptions on reading the response.
self.h11_state.send_failed()
- finally:
- # Once we've sent the request, we enable read timeouts.
- self.timeout_flag.set_read_timeouts()
async def _send_event(self, event: H11Event, timeout: Timeout) -> None:
"""
"""
while True:
event = await self._receive_event(timeout)
- # As soon as we start seeing response events, we should enable
- # read timeouts, if we haven't already.
- self.timeout_flag.set_read_timeouts()
if isinstance(event, h11.InformationalResponse):
continue
else:
if event is h11.NEED_DATA:
try:
- data = await self.socket.read(
- self.READ_NUM_BYTES, timeout, flag=self.timeout_flag
- )
+ data = await self.socket.read(self.READ_NUM_BYTES, timeout)
except OSError: # pragma: nocover
data = b""
self.h11_state.receive_data(data)
):
# Get ready for another request/response cycle.
self.h11_state.start_next_cycle()
- self.timeout_flag.set_write_timeouts()
else:
await self.close()