]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Pool timeouts should be on the TimeoutConfig, not PoolLimits (#563)
authorTom Christie <tom@tomchristie.com>
Fri, 29 Nov 2019 12:01:51 +0000 (12:01 +0000)
committerGitHub <noreply@github.com>
Fri, 29 Nov 2019 12:01:51 +0000 (12:01 +0000)
* Pool timeouts should be on the TimeoutConfig, not PoolLimits

* Linting

* Fix type annotation

* Linting

httpx/concurrency/asyncio.py
httpx/concurrency/base.py
httpx/concurrency/trio.py
httpx/config.py
httpx/dispatch/connection_pool.py
httpx/dispatch/proxy_http.py
tests/test_config.py
tests/test_timeouts.py

index e7ed63abb388eac073b88f3b6d1550b14434fb3f..c3fcae0a7d38db620cbd933ce9af730d39dc1c84 100644 (file)
@@ -226,11 +226,10 @@ class PoolSemaphore(BasePoolSemaphore):
                 self._semaphore = asyncio.BoundedSemaphore(value=max_connections)
         return self._semaphore
 
-    async def acquire(self) -> None:
+    async def acquire(self, timeout: float = None) -> None:
         if self.semaphore is None:
             return
 
-        timeout = self.pool_limits.pool_timeout
         try:
             await asyncio.wait_for(self.semaphore.acquire(), timeout)
         except asyncio.TimeoutError:
index 172e7c2d5c87cc9f966a45b062f41b3b8315f4f5..6bbeb071ed541d651b8dd967ab4e9eb3cd7d78e5 100644 (file)
@@ -95,7 +95,7 @@ class BasePoolSemaphore:
     Abstracts away any asyncio-specific interfaces.
     """
 
-    async def acquire(self) -> None:
+    async def acquire(self, timeout: float = None) -> None:
         raise NotImplementedError()  # pragma: no cover
 
     def release(self) -> None:
index ed72d2150a5376246e848a735f50ba418feba879..0169f544c4111441fb99be48b3165b290e3ee9f1 100644 (file)
@@ -151,11 +151,11 @@ class PoolSemaphore(BasePoolSemaphore):
                 )
         return self._semaphore
 
-    async def acquire(self) -> None:
+    async def acquire(self, timeout: float = None) -> None:
         if self.semaphore is None:
             return
 
-        timeout = _or_inf(self.pool_limits.pool_timeout)
+        timeout = _or_inf(timeout)
 
         with trio.move_on_after(timeout):
             await self.semaphore.acquire()
index 910f91b4f8ee3b9cc5f6f1142bdb8c2a45075524..a32ae84df9599daf373c309cdc858f88746b049a 100644 (file)
@@ -10,7 +10,9 @@ from .utils import get_ca_bundle_from_env, get_logger
 
 CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
 VerifyTypes = typing.Union[str, bool, ssl.SSLContext]
-TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
+TimeoutTypes = typing.Union[
+    float, typing.Tuple[float, float, float, float], "TimeoutConfig"
+]
 HTTPVersionTypes = typing.Union[
     str, typing.List[str], typing.Tuple[str], "HTTPVersionConfig"
 ]
@@ -227,28 +229,34 @@ class TimeoutConfig:
         connect_timeout: float = None,
         read_timeout: float = None,
         write_timeout: float = None,
+        pool_timeout: float = None,
     ):
         if timeout is None:
             self.connect_timeout = connect_timeout
             self.read_timeout = read_timeout
             self.write_timeout = write_timeout
+            self.pool_timeout = pool_timeout
         else:
             # Specified as a single timeout value
             assert connect_timeout is None
             assert read_timeout is None
             assert write_timeout is None
+            assert pool_timeout is None
             if isinstance(timeout, TimeoutConfig):
                 self.connect_timeout = timeout.connect_timeout
                 self.read_timeout = timeout.read_timeout
                 self.write_timeout = timeout.write_timeout
+                self.pool_timeout = timeout.pool_timeout
             elif isinstance(timeout, tuple):
                 self.connect_timeout = timeout[0]
                 self.read_timeout = timeout[1]
-                self.write_timeout = timeout[2]
+                self.write_timeout = None if len(timeout) < 3 else timeout[2]
+                self.pool_timeout = None if len(timeout) < 4 else timeout[3]
             else:
                 self.connect_timeout = timeout
                 self.read_timeout = timeout
                 self.write_timeout = timeout
+                self.pool_timeout = timeout
 
     def __eq__(self, other: typing.Any) -> bool:
         return (
@@ -256,15 +264,27 @@ class TimeoutConfig:
             and self.connect_timeout == other.connect_timeout
             and self.read_timeout == other.read_timeout
             and self.write_timeout == other.write_timeout
+            and self.pool_timeout == other.pool_timeout
         )
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
-        if len({self.connect_timeout, self.read_timeout, self.write_timeout}) == 1:
+        if (
+            len(
+                {
+                    self.connect_timeout,
+                    self.read_timeout,
+                    self.write_timeout,
+                    self.pool_timeout,
+                }
+            )
+            == 1
+        ):
             return f"{class_name}(timeout={self.connect_timeout})"
         return (
             f"{class_name}(connect_timeout={self.connect_timeout}, "
-            f"read_timeout={self.read_timeout}, write_timeout={self.write_timeout})"
+            f"read_timeout={self.read_timeout}, write_timeout={self.write_timeout}, "
+            f"pool_timeout={self.pool_timeout})"
         )
 
 
@@ -320,34 +340,27 @@ class PoolLimits:
     """
 
     def __init__(
-        self,
-        *,
-        soft_limit: int = None,
-        hard_limit: int = None,
-        pool_timeout: float = None,
+        self, *, soft_limit: int = None, hard_limit: int = None,
     ):
         self.soft_limit = soft_limit
         self.hard_limit = hard_limit
-        self.pool_timeout = pool_timeout
 
     def __eq__(self, other: typing.Any) -> bool:
         return (
             isinstance(other, self.__class__)
             and self.soft_limit == other.soft_limit
             and self.hard_limit == other.hard_limit
-            and self.pool_timeout == other.pool_timeout
         )
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
         return (
-            f"{class_name}(soft_limit={self.soft_limit}, "
-            f"hard_limit={self.hard_limit}, pool_timeout={self.pool_timeout})"
+            f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit})"
         )
 
 
 DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True)
 DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
-DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100, pool_timeout=5.0)
+DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100)
 DEFAULT_CA_BUNDLE_PATH = Path(certifi.where())
 DEFAULT_MAX_REDIRECTS = 20
index d64a1930a8a5fb0f39515e0950d5cf87cefd68dd..32173301b713fdcfd1e35c3e702fd5a1cff248c9 100644 (file)
@@ -8,6 +8,7 @@ from ..config import (
     CertTypes,
     HTTPVersionTypes,
     PoolLimits,
+    TimeoutConfig,
     TimeoutTypes,
     VerifyTypes,
 )
@@ -93,7 +94,7 @@ class ConnectionPool(Dispatcher):
     ):
         self.verify = verify
         self.cert = cert
-        self.timeout = timeout
+        self.timeout = TimeoutConfig(timeout)
         self.pool_limits = pool_limits
         self.http_versions = http_versions
         self.is_closed = False
@@ -117,7 +118,9 @@ class ConnectionPool(Dispatcher):
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
     ) -> Response:
-        connection = await self.acquire_connection(origin=request.url.origin)
+        connection = await self.acquire_connection(
+            origin=request.url.origin, timeout=timeout
+        )
         try:
             response = await connection.send(
                 request, verify=verify, cert=cert, timeout=timeout
@@ -129,12 +132,19 @@ class ConnectionPool(Dispatcher):
 
         return response
 
-    async def acquire_connection(self, origin: Origin) -> HTTPConnection:
+    async def acquire_connection(
+        self, origin: Origin, timeout: TimeoutTypes = None
+    ) -> HTTPConnection:
         logger.trace(f"acquire_connection origin={origin!r}")
         connection = self.pop_connection(origin)
 
         if connection is None:
-            await self.max_connections.acquire()
+            if timeout is None:
+                pool_timeout = self.timeout.pool_timeout
+            else:
+                pool_timeout = TimeoutConfig(timeout).pool_timeout
+
+            await self.max_connections.acquire(timeout=pool_timeout)
             connection = HTTPConnection(
                 origin,
                 verify=self.verify,
index 54516348f8a4a908f2eb77f48a5b19098b698384..13eb42333e3162943412aae49218848580bd622b 100644 (file)
@@ -81,12 +81,14 @@ class HTTPProxy(ConnectionPool):
         token = b64encode(b":".join(userpass)).decode().strip()
         return f"Basic {token}"
 
-    async def acquire_connection(self, origin: Origin) -> HTTPConnection:
+    async def acquire_connection(
+        self, origin: Origin, timeout: TimeoutTypes = None
+    ) -> HTTPConnection:
         if self.should_forward_origin(origin):
             logger.trace(
                 f"forward_connection proxy_url={self.proxy_url!r} origin={origin!r}"
             )
-            return await super().acquire_connection(self.proxy_url.origin)
+            return await super().acquire_connection(self.proxy_url.origin, timeout)
         else:
             logger.trace(
                 f"tunnel_connection proxy_url={self.proxy_url!r} origin={origin!r}"
index 498116cf1d3c6e4860c98a54bccbee44d47a8fef..d8d429d811de1ec29c52f60d30c22a4237cc5bf3 100644 (file)
@@ -160,9 +160,7 @@ def test_empty_http_version():
 
 def test_limits_repr():
     limits = httpx.PoolLimits(hard_limit=100)
-    assert (
-        repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100, pool_timeout=None)"
-    )
+    assert repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100)"
 
 
 def test_ssl_eq():
@@ -185,6 +183,7 @@ def test_timeout_from_nothing():
     assert timeout.connect_timeout is None
     assert timeout.read_timeout is None
     assert timeout.write_timeout is None
+    assert timeout.pool_timeout is None
 
 
 def test_timeout_from_none():
@@ -198,7 +197,7 @@ def test_timeout_from_one_none_value():
 
 
 def test_timeout_from_tuple():
-    timeout = httpx.TimeoutConfig(timeout=(5.0, 5.0, 5.0))
+    timeout = httpx.TimeoutConfig(timeout=(5.0, 5.0, 5.0, 5.0))
     assert timeout == httpx.TimeoutConfig(timeout=5.0)
 
 
@@ -212,9 +211,9 @@ def test_timeout_repr():
     assert repr(timeout) == "TimeoutConfig(timeout=5.0)"
 
     timeout = httpx.TimeoutConfig(read_timeout=5.0)
-    assert (
-        repr(timeout)
-        == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None)"
+    assert repr(timeout) == (
+        "TimeoutConfig(connect_timeout=None, read_timeout=5.0, "
+        "write_timeout=None, pool_timeout=None)"
     )
 
 
index bd6bf8c4c4e4db820f83486fe5c87010f10440dd..89cdc1e0a0226b96da111e9d50c8eb7bf6776707 100644 (file)
@@ -38,9 +38,12 @@ async def test_connect_timeout(server, backend):
 
 
 async def test_pool_timeout(server, backend):
-    pool_limits = PoolLimits(hard_limit=1, pool_timeout=1e-4)
+    pool_limits = PoolLimits(hard_limit=1)
+    timeout = TimeoutConfig(pool_timeout=1e-4)
 
-    async with Client(pool_limits=pool_limits, backend=backend) as client:
+    async with Client(
+        pool_limits=pool_limits, timeout=timeout, backend=backend
+    ) as client:
         response = await client.get(server.url, stream=True)
 
         with pytest.raises(PoolTimeout):