Limits on the number of connections in a connection pool.
"""
- def __init__(self, *, max_hosts: int, conns_per_host: int, hard_limit: bool):
- self.max_hosts = max_hosts
- self.conns_per_host = conns_per_host
+ def __init__(
+ self,
+ *,
+ soft_limit: typing.Optional[int] = None,
+ hard_limit: typing.Optional[int] = None
+ ):
+ self.soft_limit = soft_limit
self.hard_limit = hard_limit
DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True)
DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
-DEFAULT_POOL_LIMITS = PoolLimits(max_hosts=10, conns_per_host=10, hard_limit=False)
+DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100)
DEFAULT_CA_BUNDLE_PATH = certifi.where()
if self.state.our_state is h11.DONE and self.state.their_state is h11.DONE:
self.state.start_next_cycle()
else:
- event = h11.ConnectionClosed()
- try:
- # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
- self.state.send(event)
- except h11.ProtocolError:
- # If we're in some other state then it's a premature close,
- # and we'll end up in h11.ERROR.
- pass
-
- if self.is_closed:
- self.writer.close()
- if hasattr(self.writer, "wait_closed"):
- await self.writer.wait_closed()
+ self.close()
if self.on_release is not None:
await self.on_release(self)
+
+ def close(self) -> None:
+ assert self.writer is not None
+
+ event = h11.ConnectionClosed()
+ try:
+ # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
+ self.state.send(event)
+ except h11.ProtocolError:
+ # If we're in some other state then it's a premature close,
+ # and we'll end up in h11.ERROR.
+ pass
+
+ self.writer.close()
ConnectionKey = typing.Tuple[str, str, int] # (scheme, host, port)
+class ConnectionSemaphore:
+ def __init__(self, max_connections: int = None):
+ if max_connections is not None:
+ self.semaphore = asyncio.BoundedSemaphore(value=max_connections)
+
+ async def acquire(self) -> None:
+ if hasattr(self, "semaphore"):
+ await self.semaphore.acquire()
+
+ def release(self) -> None:
+ if hasattr(self, "semaphore"):
+ self.semaphore.release()
+
+
class ConnectionPool:
def __init__(
self,
self._connections = (
{}
) # type: typing.Dict[ConnectionKey, typing.List[Connection]]
+ self._connection_semaphore = ConnectionSemaphore(
+ max_connections=self.limits.hard_limit
+ )
async def request(
self,
await response.close()
return response
+ @property
+ def num_connections(self) -> int:
+ return self.num_active_connections + self.num_keepalive_connections
+
async def acquire_connection(
self, url: URL, *, ssl: typing.Optional[ssl.SSLContext] = None
) -> Connection:
self.num_active_connections += 1
except (KeyError, IndexError):
+ await self._connection_semaphore.acquire()
release = functools.partial(self.release_connection, key=key)
connection = Connection(timeout=self.timeout, on_release=release)
self.num_active_connections += 1
async def release_connection(
self, connection: Connection, key: ConnectionKey
) -> None:
- self.num_active_connections -= 1
- if not connection.is_closed:
+ if connection.is_closed:
+ self._connection_semaphore.release()
+ self.num_active_connections -= 1
+ elif (
+ self.limits.soft_limit is not None
+ and self.num_connections > self.limits.soft_limit
+ ):
+ self._connection_semaphore.release()
+ self.num_active_connections -= 1
+ connection.close()
+ else:
+ self.num_active_connections -= 1
self.num_keepalive_connections += 1
try:
self._connections[key].append(connection)
assert http.num_keepalive_connections == 2
+@pytest.mark.asyncio
+async def test_soft_limit(server):
+ """
+ The soft_limit config should limit the maximum number of keep-alive connections.
+ """
+ limits = httpcore.PoolLimits(soft_limit=1)
+
+ async with httpcore.ConnectionPool(limits=limits) as http:
+ response = await http.request("GET", "http://127.0.0.1:8000/")
+ assert http.num_active_connections == 0
+ assert http.num_keepalive_connections == 1
+
+ response = await http.request("GET", "http://localhost:8000/")
+ assert http.num_active_connections == 0
+ assert http.num_keepalive_connections == 1
+
+
@pytest.mark.asyncio
async def test_streaming_response_holds_connection(server):
"""