From: Tom Christie Date: Fri, 29 Nov 2019 12:01:51 +0000 (+0000) Subject: Pool timeouts should be on the TimeoutConfig, not PoolLimits (#563) X-Git-Tag: 0.9.0~41 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=364378a814981044484ea68960be7860e75740f0;p=thirdparty%2Fhttpx.git Pool timeouts should be on the TimeoutConfig, not PoolLimits (#563) * Pool timeouts should be on the TimeoutConfig, not PoolLimits * Linting * Fix type annotation * Linting --- diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index e7ed63ab..c3fcae0a 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -226,11 +226,10 @@ class PoolSemaphore(BasePoolSemaphore): self._semaphore = asyncio.BoundedSemaphore(value=max_connections) return self._semaphore - async def acquire(self) -> None: + async def acquire(self, timeout: float = None) -> None: if self.semaphore is None: return - timeout = self.pool_limits.pool_timeout try: await asyncio.wait_for(self.semaphore.acquire(), timeout) except asyncio.TimeoutError: diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index 172e7c2d..6bbeb071 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -95,7 +95,7 @@ class BasePoolSemaphore: Abstracts away any asyncio-specific interfaces. """ - async def acquire(self) -> None: + async def acquire(self, timeout: float = None) -> None: raise NotImplementedError() # pragma: no cover def release(self) -> None: diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index ed72d215..0169f544 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -151,11 +151,11 @@ class PoolSemaphore(BasePoolSemaphore): ) return self._semaphore - async def acquire(self) -> None: + async def acquire(self, timeout: float = None) -> None: if self.semaphore is None: return - timeout = _or_inf(self.pool_limits.pool_timeout) + timeout = _or_inf(timeout) with trio.move_on_after(timeout): await self.semaphore.acquire() diff --git a/httpx/config.py b/httpx/config.py index 910f91b4..a32ae84d 100644 --- a/httpx/config.py +++ b/httpx/config.py @@ -10,7 +10,9 @@ from .utils import get_ca_bundle_from_env, get_logger CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]] VerifyTypes = typing.Union[str, bool, ssl.SSLContext] -TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"] +TimeoutTypes = typing.Union[ + float, typing.Tuple[float, float, float, float], "TimeoutConfig" +] HTTPVersionTypes = typing.Union[ str, typing.List[str], typing.Tuple[str], "HTTPVersionConfig" ] @@ -227,28 +229,34 @@ class TimeoutConfig: connect_timeout: float = None, read_timeout: float = None, write_timeout: float = None, + pool_timeout: float = None, ): if timeout is None: self.connect_timeout = connect_timeout self.read_timeout = read_timeout self.write_timeout = write_timeout + self.pool_timeout = pool_timeout else: # Specified as a single timeout value assert connect_timeout is None assert read_timeout is None assert write_timeout is None + assert pool_timeout is None if isinstance(timeout, TimeoutConfig): self.connect_timeout = timeout.connect_timeout self.read_timeout = timeout.read_timeout self.write_timeout = timeout.write_timeout + self.pool_timeout = timeout.pool_timeout elif isinstance(timeout, tuple): self.connect_timeout = timeout[0] self.read_timeout = timeout[1] - self.write_timeout = timeout[2] + self.write_timeout = None if len(timeout) < 3 else timeout[2] + self.pool_timeout = None if len(timeout) < 4 else timeout[3] else: self.connect_timeout = timeout self.read_timeout = timeout self.write_timeout = timeout + self.pool_timeout = timeout def __eq__(self, other: typing.Any) -> bool: return ( @@ -256,15 +264,27 @@ class TimeoutConfig: and self.connect_timeout == other.connect_timeout and self.read_timeout == other.read_timeout and self.write_timeout == other.write_timeout + and self.pool_timeout == other.pool_timeout ) def __repr__(self) -> str: class_name = self.__class__.__name__ - if len({self.connect_timeout, self.read_timeout, self.write_timeout}) == 1: + if ( + len( + { + self.connect_timeout, + self.read_timeout, + self.write_timeout, + self.pool_timeout, + } + ) + == 1 + ): return f"{class_name}(timeout={self.connect_timeout})" return ( f"{class_name}(connect_timeout={self.connect_timeout}, " - f"read_timeout={self.read_timeout}, write_timeout={self.write_timeout})" + f"read_timeout={self.read_timeout}, write_timeout={self.write_timeout}, " + f"pool_timeout={self.pool_timeout})" ) @@ -320,34 +340,27 @@ class PoolLimits: """ def __init__( - self, - *, - soft_limit: int = None, - hard_limit: int = None, - pool_timeout: float = None, + self, *, soft_limit: int = None, hard_limit: int = None, ): self.soft_limit = soft_limit self.hard_limit = hard_limit - self.pool_timeout = pool_timeout def __eq__(self, other: typing.Any) -> bool: return ( isinstance(other, self.__class__) and self.soft_limit == other.soft_limit and self.hard_limit == other.hard_limit - and self.pool_timeout == other.pool_timeout ) def __repr__(self) -> str: class_name = self.__class__.__name__ return ( - f"{class_name}(soft_limit={self.soft_limit}, " - f"hard_limit={self.hard_limit}, pool_timeout={self.pool_timeout})" + f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit})" ) DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True) DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0) -DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100, pool_timeout=5.0) +DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100) DEFAULT_CA_BUNDLE_PATH = Path(certifi.where()) DEFAULT_MAX_REDIRECTS = 20 diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py index d64a1930..32173301 100644 --- a/httpx/dispatch/connection_pool.py +++ b/httpx/dispatch/connection_pool.py @@ -8,6 +8,7 @@ from ..config import ( CertTypes, HTTPVersionTypes, PoolLimits, + TimeoutConfig, TimeoutTypes, VerifyTypes, ) @@ -93,7 +94,7 @@ class ConnectionPool(Dispatcher): ): self.verify = verify self.cert = cert - self.timeout = timeout + self.timeout = TimeoutConfig(timeout) self.pool_limits = pool_limits self.http_versions = http_versions self.is_closed = False @@ -117,7 +118,9 @@ class ConnectionPool(Dispatcher): cert: CertTypes = None, timeout: TimeoutTypes = None, ) -> Response: - connection = await self.acquire_connection(origin=request.url.origin) + connection = await self.acquire_connection( + origin=request.url.origin, timeout=timeout + ) try: response = await connection.send( request, verify=verify, cert=cert, timeout=timeout @@ -129,12 +132,19 @@ class ConnectionPool(Dispatcher): return response - async def acquire_connection(self, origin: Origin) -> HTTPConnection: + async def acquire_connection( + self, origin: Origin, timeout: TimeoutTypes = None + ) -> HTTPConnection: logger.trace(f"acquire_connection origin={origin!r}") connection = self.pop_connection(origin) if connection is None: - await self.max_connections.acquire() + if timeout is None: + pool_timeout = self.timeout.pool_timeout + else: + pool_timeout = TimeoutConfig(timeout).pool_timeout + + await self.max_connections.acquire(timeout=pool_timeout) connection = HTTPConnection( origin, verify=self.verify, diff --git a/httpx/dispatch/proxy_http.py b/httpx/dispatch/proxy_http.py index 54516348..13eb4233 100644 --- a/httpx/dispatch/proxy_http.py +++ b/httpx/dispatch/proxy_http.py @@ -81,12 +81,14 @@ class HTTPProxy(ConnectionPool): token = b64encode(b":".join(userpass)).decode().strip() return f"Basic {token}" - async def acquire_connection(self, origin: Origin) -> HTTPConnection: + async def acquire_connection( + self, origin: Origin, timeout: TimeoutTypes = None + ) -> HTTPConnection: if self.should_forward_origin(origin): logger.trace( f"forward_connection proxy_url={self.proxy_url!r} origin={origin!r}" ) - return await super().acquire_connection(self.proxy_url.origin) + return await super().acquire_connection(self.proxy_url.origin, timeout) else: logger.trace( f"tunnel_connection proxy_url={self.proxy_url!r} origin={origin!r}" diff --git a/tests/test_config.py b/tests/test_config.py index 498116cf..d8d429d8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -160,9 +160,7 @@ def test_empty_http_version(): def test_limits_repr(): limits = httpx.PoolLimits(hard_limit=100) - assert ( - repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100, pool_timeout=None)" - ) + assert repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100)" def test_ssl_eq(): @@ -185,6 +183,7 @@ def test_timeout_from_nothing(): assert timeout.connect_timeout is None assert timeout.read_timeout is None assert timeout.write_timeout is None + assert timeout.pool_timeout is None def test_timeout_from_none(): @@ -198,7 +197,7 @@ def test_timeout_from_one_none_value(): def test_timeout_from_tuple(): - timeout = httpx.TimeoutConfig(timeout=(5.0, 5.0, 5.0)) + timeout = httpx.TimeoutConfig(timeout=(5.0, 5.0, 5.0, 5.0)) assert timeout == httpx.TimeoutConfig(timeout=5.0) @@ -212,9 +211,9 @@ def test_timeout_repr(): assert repr(timeout) == "TimeoutConfig(timeout=5.0)" timeout = httpx.TimeoutConfig(read_timeout=5.0) - assert ( - repr(timeout) - == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None)" + assert repr(timeout) == ( + "TimeoutConfig(connect_timeout=None, read_timeout=5.0, " + "write_timeout=None, pool_timeout=None)" ) diff --git a/tests/test_timeouts.py b/tests/test_timeouts.py index bd6bf8c4..89cdc1e0 100644 --- a/tests/test_timeouts.py +++ b/tests/test_timeouts.py @@ -38,9 +38,12 @@ async def test_connect_timeout(server, backend): async def test_pool_timeout(server, backend): - pool_limits = PoolLimits(hard_limit=1, pool_timeout=1e-4) + pool_limits = PoolLimits(hard_limit=1) + timeout = TimeoutConfig(pool_timeout=1e-4) - async with Client(pool_limits=pool_limits, backend=backend) as client: + async with Client( + pool_limits=pool_limits, timeout=timeout, backend=backend + ) as client: response = await client.get(server.url, stream=True) with pytest.raises(PoolTimeout):