]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add soft_limit and hard_limit support to PoolLimits 4/head
authorTom Christie <tom@tomchristie.com>
Wed, 17 Apr 2019 15:54:18 +0000 (16:54 +0100)
committerTom Christie <tom@tomchristie.com>
Wed, 17 Apr 2019 15:54:18 +0000 (16:54 +0100)
httpcore/config.py
httpcore/connections.py
httpcore/pool.py
tests/test_pool.py

index d169e0afb7ba22368aacec235a18e78f955e1b41..2db89342705c9effbe889da74ae17905d387282c 100644 (file)
@@ -45,13 +45,17 @@ class PoolLimits:
     Limits on the number of connections in a connection pool.
     """
 
-    def __init__(self, *, max_hosts: int, conns_per_host: int, hard_limit: bool):
-        self.max_hosts = max_hosts
-        self.conns_per_host = conns_per_host
+    def __init__(
+        self,
+        *,
+        soft_limit: typing.Optional[int] = None,
+        hard_limit: typing.Optional[int] = None
+    ):
+        self.soft_limit = soft_limit
         self.hard_limit = hard_limit
 
 
 DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True)
 DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
-DEFAULT_POOL_LIMITS = PoolLimits(max_hosts=10, conns_per_host=10, hard_limit=False)
+DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100)
 DEFAULT_CA_BUNDLE_PATH = certifi.where()
index 8d5c13bd76fdf2c05c5e0e74a409d9e757136c49..205482e8a697c00e72c5836177c987561582f0bf 100644 (file)
@@ -111,19 +111,21 @@ class Connection:
         if self.state.our_state is h11.DONE and self.state.their_state is h11.DONE:
             self.state.start_next_cycle()
         else:
-            event = h11.ConnectionClosed()
-            try:
-                # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
-                self.state.send(event)
-            except h11.ProtocolError:
-                # If we're in some other state then it's a premature close,
-                # and we'll end up in h11.ERROR.
-                pass
-
-        if self.is_closed:
-            self.writer.close()
-            if hasattr(self.writer, "wait_closed"):
-                await self.writer.wait_closed()
+            self.close()
 
         if self.on_release is not None:
             await self.on_release(self)
+
+    def close(self) -> None:
+        assert self.writer is not None
+
+        event = h11.ConnectionClosed()
+        try:
+            # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
+            self.state.send(event)
+        except h11.ProtocolError:
+            # If we're in some other state then it's a premature close,
+            # and we'll end up in h11.ERROR.
+            pass
+
+        self.writer.close()
index 6b4d328ada5c6919a50f3f13a37c5c49c5afa5b0..fcec56ea4c3170eb15b884b683e038898ee119b4 100644 (file)
@@ -20,6 +20,20 @@ from .datastructures import URL, Request, Response
 ConnectionKey = typing.Tuple[str, str, int]  # (scheme, host, port)
 
 
+class ConnectionSemaphore:
+    def __init__(self, max_connections: int = None):
+        if max_connections is not None:
+            self.semaphore = asyncio.BoundedSemaphore(value=max_connections)
+
+    async def acquire(self) -> None:
+        if hasattr(self, "semaphore"):
+            await self.semaphore.acquire()
+
+    def release(self) -> None:
+        if hasattr(self, "semaphore"):
+            self.semaphore.release()
+
+
 class ConnectionPool:
     def __init__(
         self,
@@ -37,6 +51,9 @@ class ConnectionPool:
         self._connections = (
             {}
         )  # type: typing.Dict[ConnectionKey, typing.List[Connection]]
+        self._connection_semaphore = ConnectionSemaphore(
+            max_connections=self.limits.hard_limit
+        )
 
     async def request(
         self,
@@ -59,6 +76,10 @@ class ConnectionPool:
                 await response.close()
         return response
 
+    @property
+    def num_connections(self) -> int:
+        return self.num_active_connections + self.num_keepalive_connections
+
     async def acquire_connection(
         self, url: URL, *, ssl: typing.Optional[ssl.SSLContext] = None
     ) -> Connection:
@@ -71,6 +92,7 @@ class ConnectionPool:
             self.num_active_connections += 1
 
         except (KeyError, IndexError):
+            await self._connection_semaphore.acquire()
             release = functools.partial(self.release_connection, key=key)
             connection = Connection(timeout=self.timeout, on_release=release)
             self.num_active_connections += 1
@@ -81,8 +103,18 @@ class ConnectionPool:
     async def release_connection(
         self, connection: Connection, key: ConnectionKey
     ) -> None:
-        self.num_active_connections -= 1
-        if not connection.is_closed:
+        if connection.is_closed:
+            self._connection_semaphore.release()
+            self.num_active_connections -= 1
+        elif (
+            self.limits.soft_limit is not None
+            and self.num_connections > self.limits.soft_limit
+        ):
+            self._connection_semaphore.release()
+            self.num_active_connections -= 1
+            connection.close()
+        else:
+            self.num_active_connections -= 1
             self.num_keepalive_connections += 1
             try:
                 self._connections[key].append(connection)
index 444d51c81edae9f5c9aefc5a92e1f488655cc5e2..77a221575441fda934bd4799ed8a4a33b00726f0 100644 (file)
@@ -33,6 +33,23 @@ async def test_differing_connection_keys(server):
         assert http.num_keepalive_connections == 2
 
 
+@pytest.mark.asyncio
+async def test_soft_limit(server):
+    """
+    The soft_limit config should limit the maximum number of keep-alive connections.
+    """
+    limits = httpcore.PoolLimits(soft_limit=1)
+
+    async with httpcore.ConnectionPool(limits=limits) as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/")
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
+
+        response = await http.request("GET", "http://localhost:8000/")
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
+
+
 @pytest.mark.asyncio
 async def test_streaming_response_holds_connection(server):
     """