From: Tom Christie Date: Tue, 16 Apr 2019 14:58:38 +0000 (+0100) Subject: Connection pooling X-Git-Tag: 0.1.0~7^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5f3a8bdb4d95490f0119a2d35c3a59cfb4711e81;p=thirdparty%2Fhttpx.git Connection pooling --- diff --git a/httpcore/connections.py b/httpcore/connections.py index f39a365f..8d5c13bd 100644 --- a/httpcore/connections.py +++ b/httpcore/connections.py @@ -19,18 +19,19 @@ H11Event = typing.Union[ class Connection: - def __init__(self, timeout: TimeoutConfig): + def __init__(self, timeout: TimeoutConfig, on_release: typing.Callable = None): self.reader = None self.writer = None self.state = h11.Connection(our_role=h11.CLIENT) self.timeout = timeout + self.on_release = on_release + + @property + def is_closed(self) -> bool: + return self.state.our_state in (h11.CLOSED, h11.ERROR) async def open( - self, - hostname: str, - port: int, - *, - ssl: typing.Union[bool, ssl.SSLContext] = False + self, hostname: str, port: int, *, ssl: typing.Optional[ssl.SSLContext] = None ) -> None: try: self.reader, self.writer = await asyncio.wait_for( # type: ignore @@ -69,18 +70,17 @@ class Connection: assert isinstance(event, h11.Response) status_code = event.status_code headers = event.headers - body = self.body_iter() + body = self._body_iter() return Response( - status_code=status_code, headers=headers, body=body, on_close=self.close + status_code=status_code, headers=headers, body=body, on_close=self._release ) - async def body_iter(self) -> typing.AsyncIterator[bytes]: + async def _body_iter(self) -> typing.AsyncIterator[bytes]: event = await self._receive_event() while isinstance(event, h11.Data): yield event.data event = await self._receive_event() assert isinstance(event, h11.EndOfMessage) - await self.close() async def _send_event(self, event: H11Event) -> None: assert self.writer is not None @@ -105,8 +105,25 @@ class Connection: return event - async def close(self) -> None: - if self.writer is not None: + async def _release(self) -> None: + assert self.writer is not None + + if self.state.our_state is h11.DONE and self.state.their_state is h11.DONE: + self.state.start_next_cycle() + else: + event = h11.ConnectionClosed() + try: + # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED. + self.state.send(event) + except h11.ProtocolError: + # If we're in some other state then it's a premature close, + # and we'll end up in h11.ERROR. + pass + + if self.is_closed: self.writer.close() if hasattr(self.writer, "wait_closed"): await self.writer.wait_closed() + + if self.on_release is not None: + await self.on_release(self) diff --git a/httpcore/pool.py b/httpcore/pool.py index 6446c39a..6b4d328a 100644 --- a/httpcore/pool.py +++ b/httpcore/pool.py @@ -1,4 +1,5 @@ import asyncio +import functools import os import ssl import typing @@ -16,6 +17,8 @@ from .config import ( from .connections import Connection from .datastructures import URL, Request, Response +ConnectionKey = typing.Tuple[str, str, int] # (scheme, host, port) + class ConnectionPool: def __init__( @@ -29,6 +32,11 @@ class ConnectionPool: self.timeout = timeout self.limits = limits self.is_closed = False + self.num_active_connections = 0 + self.num_keepalive_connections = 0 + self._connections = ( + {} + ) # type: typing.Dict[ConnectionKey, typing.List[Connection]] async def request( self, @@ -52,15 +60,38 @@ class ConnectionPool: return response async def acquire_connection( - self, url: URL, *, ssl: typing.Union[bool, ssl.SSLContext] = False + self, url: URL, *, ssl: typing.Optional[ssl.SSLContext] = None ) -> Connection: - connection = Connection(timeout=self.timeout) - await connection.open(url.hostname, url.port, ssl=ssl) + key = (url.scheme, url.hostname, url.port) + try: + connection = self._connections[key].pop() + if not self._connections[key]: + del self._connections[key] + self.num_keepalive_connections -= 1 + self.num_active_connections += 1 + + except (KeyError, IndexError): + release = functools.partial(self.release_connection, key=key) + connection = Connection(timeout=self.timeout, on_release=release) + self.num_active_connections += 1 + await connection.open(url.hostname, url.port, ssl=ssl) + return connection - async def get_ssl_context(self, url: URL) -> typing.Union[bool, ssl.SSLContext]: + async def release_connection( + self, connection: Connection, key: ConnectionKey + ) -> None: + self.num_active_connections -= 1 + if not connection.is_closed: + self.num_keepalive_connections += 1 + try: + self._connections[key].append(connection) + except KeyError: + self._connections[key] = [connection] + + async def get_ssl_context(self, url: URL) -> typing.Optional[ssl.SSLContext]: if not url.is_secure: - return False + return None if not hasattr(self, "ssl_context"): if not self.ssl_config.verify: diff --git a/tests/conftest.py b/tests/conftest.py index 234cf43b..efb79df1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,4 +27,6 @@ async def server(): await asyncio.sleep(0.0001) yield server finally: - task.cancel() + server.should_exit = True + server.force_exit = True + await task diff --git a/tests/test_pool.py b/tests/test_pool.py new file mode 100644 index 00000000..444d51c8 --- /dev/null +++ b/tests/test_pool.py @@ -0,0 +1,109 @@ +import pytest + +import httpcore + + +@pytest.mark.asyncio +async def test_keepalive_connections(server): + """ + Connections should default to staying in a keep-alive state. + """ + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/") + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 1 + + response = await http.request("GET", "http://127.0.0.1:8000/") + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 1 + + +@pytest.mark.asyncio +async def test_differing_connection_keys(server): + """ + Connnections to differing connection keys should result in multiple connections. + """ + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/") + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 1 + + response = await http.request("GET", "http://localhost:8000/") + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 2 + + +@pytest.mark.asyncio +async def test_streaming_response_holds_connection(server): + """ + A streaming request should hold the connection open until the response is read. + """ + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + assert http.num_active_connections == 1 + assert http.num_keepalive_connections == 0 + + await response.read() + + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 1 + + +@pytest.mark.asyncio +async def test_multiple_concurrent_connections(server): + """ + Multiple conncurrent requests should open multiple conncurrent connections. + """ + async with httpcore.ConnectionPool() as http: + response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + assert http.num_active_connections == 1 + assert http.num_keepalive_connections == 0 + + response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + assert http.num_active_connections == 2 + assert http.num_keepalive_connections == 0 + + await response_b.read() + assert http.num_active_connections == 1 + assert http.num_keepalive_connections == 1 + + await response_a.read() + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 2 + + +@pytest.mark.asyncio +async def test_close_connections(server): + """ + Using a `Connection: close` header should close the connection. + """ + headers = [(b"connection", b"close")] + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers) + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 0 + + +@pytest.mark.asyncio +async def test_standard_response_close(server): + """ + A standard close should keep the connection open. + """ + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + await response.read() + await response.close() + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 1 + + +@pytest.mark.asyncio +async def test_premature_response_close(server): + """ + A premature close should close the connection. + """ + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + await response.close() + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 0