From: Tom Christie Date: Thu, 25 Apr 2019 14:31:47 +0000 (+0100) Subject: Tighten up connection acquiry/release X-Git-Tag: 0.3.0~66^2~15 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=28c505a70f080bca9fd9e5549a65c7dbdc9ce195;p=thirdparty%2Fhttpx.git Tighten up connection acquiry/release --- diff --git a/httpcore/connection.py b/httpcore/connection.py index f164232f..a07a062a 100644 --- a/httpcore/connection.py +++ b/httpcore/connection.py @@ -1,3 +1,4 @@ +import functools import typing import h2.connection @@ -17,12 +18,12 @@ class HTTPConnection(Client): origin: typing.Union[str, Origin], ssl: SSLConfig = DEFAULT_SSL_CONFIG, timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, - on_release: typing.Callable = None, + pool_release_func: typing.Callable = None, ): self.origin = Origin(origin) if isinstance(origin, str) else origin self.ssl = ssl self.timeout = timeout - self.on_release = on_release + self.pool_release_func = pool_release_func self.h11_connection = None # type: typing.Optional[HTTP11Connection] self.h2_connection = None # type: typing.Optional[HTTP2Connection] @@ -43,6 +44,11 @@ class HTTPConnection(Client): port = self.origin.port ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None + if self.pool_release_func is None: + on_release = None + else: + on_release = functools.partial(self.pool_release_func, self) + reader, writer, protocol = await connect( hostname, port, ssl_context, timeout ) @@ -52,7 +58,7 @@ class HTTPConnection(Client): writer, origin=self.origin, timeout=self.timeout, - on_release=self.on_release, + on_release=on_release, ) else: self.h11_connection = HTTP11Connection( @@ -60,7 +66,7 @@ class HTTPConnection(Client): writer, origin=self.origin, timeout=self.timeout, - on_release=self.on_release, + on_release=on_release, ) if self.h2_connection is not None: diff --git a/httpcore/connection_pool.py b/httpcore/connection_pool.py index 894212ab..b5414357 100644 --- a/httpcore/connection_pool.py +++ b/httpcore/connection_pool.py @@ -1,3 +1,4 @@ +import collections.abc import typing from .config import ( @@ -14,6 +15,66 @@ from .exceptions import PoolTimeout from .models import Client, Origin, Request, Response from .streams import PoolSemaphore +CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]] + + +class ConnectionStore(collections.abc.Sequence): + """ + We need to maintain collections of connections in a way that allows us to: + + * Lookup connections by origin. + * Iterate over connections by insertion time. + * Return the total number of connections. + """ + + def __init__(self) -> None: + self.all = {} # type: typing.Dict[HTTPConnection, float] + self.by_origin = ( + {} + ) # type: typing.Dict[Origin, typing.Dict[HTTPConnection, float]] + + def pop_by_origin(self, origin: Origin) -> typing.Optional[HTTPConnection]: + try: + connections = self.by_origin[origin] + except KeyError: + return None + + connection = next(reversed(list(connections.keys()))) + del connections[connection] + if not connections: + del self.by_origin[origin] + del self.all[connection] + + return connection + + def add(self, connection: HTTPConnection) -> None: + self.all[connection] = 0.0 + try: + self.by_origin[connection.origin][connection] = 0.0 + except KeyError: + self.by_origin[connection.origin] = {connection: 0.0} + + def remove(self, connection: HTTPConnection) -> None: + del self.all[connection] + del self.by_origin[connection.origin][connection] + if not self.by_origin[connection.origin]: + del self.by_origin[connection.origin] + + def clear(self) -> None: + self.all.clear() + self.by_origin.clear() + + def __iter__(self) -> typing.Iterator[HTTPConnection]: + return iter(self.all.keys()) + + def __getitem__(self, key: typing.Any) -> typing.Any: + if key in self.all: + return key + return None + + def __len__(self) -> int: + return len(self.all) + class ConnectionPool(Client): def __init__( @@ -27,12 +88,14 @@ class ConnectionPool(Client): self.timeout = timeout self.limits = limits self.is_closed = False - self.num_active_connections = 0 - self.num_keepalive_connections = 0 - self._keepalive_connections = ( - {} - ) # type: typing.Dict[Origin, typing.List[HTTPConnection]] - self._max_connections = PoolSemaphore(limits, timeout) + + self.max_connections = PoolSemaphore(limits, timeout) + self.keepalive_connections = ConnectionStore() + self.active_connections = ConnectionStore() + + @property + def num_connections(self) -> int: + return len(self.keepalive_connections) + len(self.active_connections) async def send( self, @@ -45,56 +108,42 @@ class ConnectionPool(Client): response = await connection.send(request, ssl=ssl, timeout=timeout) return response - @property - def num_connections(self) -> int: - return self.num_active_connections + self.num_keepalive_connections - async def acquire_connection( self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None ) -> HTTPConnection: - try: - connection = self._keepalive_connections[origin].pop() - if not self._keepalive_connections[origin]: - del self._keepalive_connections[origin] - self.num_keepalive_connections -= 1 - self.num_active_connections += 1 - - except (KeyError, IndexError): - await self._max_connections.acquire(timeout) + connection = self.keepalive_connections.pop_by_origin(origin) + + if connection is None: + await self.max_connections.acquire(timeout) connection = HTTPConnection( origin, ssl=self.ssl, timeout=self.timeout, - on_release=self.release_connection, + pool_release_func=self.release_connection, ) - self.num_active_connections += 1 + + self.active_connections.add(connection) return connection async def release_connection(self, connection: HTTPConnection) -> None: if connection.is_closed: - self._max_connections.release() - self.num_active_connections -= 1 + self.active_connections.remove(connection) + self.max_connections.release() elif ( self.limits.soft_limit is not None and self.num_connections > self.limits.soft_limit ): - self._max_connections.release() - self.num_active_connections -= 1 + self.active_connections.remove(connection) + self.max_connections.release() await connection.close() else: - self.num_active_connections -= 1 - self.num_keepalive_connections += 1 - try: - self._keepalive_connections[connection.origin].append(connection) - except KeyError: - self._keepalive_connections[connection.origin] = [connection] + self.active_connections.remove(connection) + self.keepalive_connections.add(connection) async def close(self) -> None: self.is_closed = True - all_connections = [] - for connections in self._keepalive_connections.values(): - all_connections.extend(list(connections)) - self._keepalive_connections.clear() - for connection in all_connections: + connections = list(self.keepalive_connections) + self.keepalive_connections.clear() + for connection in connections: await connection.close() diff --git a/httpcore/http11.py b/httpcore/http11.py index 253865fe..3280e1de 100644 --- a/httpcore/http11.py +++ b/httpcore/http11.py @@ -82,7 +82,7 @@ class HTTP11Connection(Client): protocol="HTTP/1.1", headers=headers, body=body, - on_close=self._release, + on_close=self.response_closed, ) async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]: @@ -106,7 +106,7 @@ class HTTP11Connection(Client): return event - async def _release(self) -> None: + async def response_closed(self) -> None: if ( self.h11_state.our_state is h11.DONE and self.h11_state.their_state is h11.DONE @@ -116,7 +116,7 @@ class HTTP11Connection(Client): await self.close() if self.on_release is not None: - await self.on_release(self) + await self.on_release() async def close(self) -> None: event = h11.ConnectionClosed() diff --git a/httpcore/http2.py b/httpcore/http2.py index f8d2b648..2b6c8e6d 100644 --- a/httpcore/http2.py +++ b/httpcore/http2.py @@ -141,7 +141,7 @@ class HTTP2Connection(Client): async def release(self) -> None: if self.on_release is not None: - await self.on_release(self) + await self.on_release() async def close(self) -> None: await self.writer.close() diff --git a/httpcore/streams.py b/httpcore/streams.py index cba51fd7..e46ffee4 100644 --- a/httpcore/streams.py +++ b/httpcore/streams.py @@ -14,7 +14,7 @@ import ssl import typing from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig -from .exceptions import ConnectTimeout, ReadTimeout, PoolTimeout, WriteTimeout +from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout OptionalTimeout = typing.Optional[TimeoutConfig] diff --git a/tests/test_connection_pools.py b/tests/test_connection_pools.py index 77a22157..7d478c5a 100644 --- a/tests/test_connection_pools.py +++ b/tests/test_connection_pools.py @@ -10,12 +10,12 @@ async def test_keepalive_connections(server): """ async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/") - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 response = await http.request("GET", "http://127.0.0.1:8000/") - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 @pytest.mark.asyncio @@ -25,12 +25,12 @@ async def test_differing_connection_keys(server): """ async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/") - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 response = await http.request("GET", "http://localhost:8000/") - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 2 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 2 @pytest.mark.asyncio @@ -42,12 +42,12 @@ async def test_soft_limit(server): async with httpcore.ConnectionPool(limits=limits) as http: response = await http.request("GET", "http://127.0.0.1:8000/") - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 response = await http.request("GET", "http://localhost:8000/") - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 @pytest.mark.asyncio @@ -57,13 +57,13 @@ async def test_streaming_response_holds_connection(server): """ async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) - assert http.num_active_connections == 1 - assert http.num_keepalive_connections == 0 + assert len(http.active_connections) == 1 + assert len(http.keepalive_connections) == 0 await response.read() - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 @pytest.mark.asyncio @@ -73,20 +73,20 @@ async def test_multiple_concurrent_connections(server): """ async with httpcore.ConnectionPool() as http: response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True) - assert http.num_active_connections == 1 - assert http.num_keepalive_connections == 0 + assert len(http.active_connections) == 1 + assert len(http.keepalive_connections) == 0 response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True) - assert http.num_active_connections == 2 - assert http.num_keepalive_connections == 0 + assert len(http.active_connections) == 2 + assert len(http.keepalive_connections) == 0 await response_b.read() - assert http.num_active_connections == 1 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 1 + assert len(http.keepalive_connections) == 1 await response_a.read() - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 2 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 2 @pytest.mark.asyncio @@ -97,8 +97,8 @@ async def test_close_connections(server): headers = [(b"connection", b"close")] async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers) - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 0 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 0 @pytest.mark.asyncio @@ -110,8 +110,8 @@ async def test_standard_response_close(server): response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) await response.read() await response.close() - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 @pytest.mark.asyncio @@ -122,5 +122,5 @@ async def test_premature_response_close(server): async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) await response.close() - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 0 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 0