From: Tom Christie Date: Mon, 8 Jul 2019 14:57:29 +0000 (+0100) Subject: Check for connection aliveness on pool re-acquiry (#111) X-Git-Tag: 0.6.7~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=025240c3b3e77f4295f8061431fa3b0665ad996c;p=thirdparty%2Fhttpx.git Check for connection aliveness on pool re-acquiry (#111) * Check for connection aliveness on pool re-acquiry * Test for connection re-acquiry with HTTP/2 * nocover on HTTP/1.1 ConnectionResetError, since we're testing the equivelent on HTTP/2 * Fix setup.py to read version from __version__ --- diff --git a/http3/dispatch/connection_pool.py b/http3/dispatch/connection_pool.py index c84117ca..d9794db4 100644 --- a/http3/dispatch/connection_pool.py +++ b/http3/dispatch/connection_pool.py @@ -11,7 +11,7 @@ from ..config import ( VerifyTypes, ) from ..decoders import ACCEPT_ENCODING -from ..exceptions import PoolTimeout +from ..exceptions import NotConnected, PoolTimeout from ..interfaces import AsyncDispatcher, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse, Origin from .connection import HTTPConnection @@ -110,21 +110,35 @@ class ConnectionPool(AsyncDispatcher): cert: CertTypes = None, timeout: TimeoutTypes = None, ) -> AsyncResponse: - connection = await self.acquire_connection(request.url.origin) - try: - response = await connection.send( - request, verify=verify, cert=cert, timeout=timeout + allow_connection_reuse = True + connection = None + while connection is None: + connection = await self.acquire_connection( + origin=request.url.origin, allow_connection_reuse=allow_connection_reuse ) - except BaseException as exc: - self.active_connections.remove(connection) - self.max_connections.release() - raise exc + try: + response = await connection.send( + request, verify=verify, cert=cert, timeout=timeout + ) + except BaseException as exc: + self.active_connections.remove(connection) + self.max_connections.release() + if isinstance(exc, NotConnected) and allow_connection_reuse: + connection = None + allow_connection_reuse = False + else: + raise exc + return response - async def acquire_connection(self, origin: Origin) -> HTTPConnection: - connection = self.active_connections.pop_by_origin(origin, http2_only=True) - if connection is None: - connection = self.keepalive_connections.pop_by_origin(origin) + async def acquire_connection( + self, origin: Origin, allow_connection_reuse: bool = True + ) -> HTTPConnection: + connection = None + if allow_connection_reuse: + connection = self.active_connections.pop_by_origin(origin, http2_only=True) + if connection is None: + connection = self.keepalive_connections.pop_by_origin(origin) if connection is None: await self.max_connections.acquire() diff --git a/http3/dispatch/http11.py b/http3/dispatch/http11.py index b15e1588..e4124412 100644 --- a/http3/dispatch/http11.py +++ b/http3/dispatch/http11.py @@ -4,7 +4,7 @@ import h11 from ..concurrency import TimeoutFlag from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes -from ..exceptions import ConnectTimeout, ReadTimeout +from ..exceptions import ConnectTimeout, NotConnected, ReadTimeout from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse @@ -46,7 +46,13 @@ class HTTP11Connection: ) -> AsyncResponse: timeout = None if timeout is None else TimeoutConfig(timeout) - await self._send_request(request, timeout) + try: + await self._send_request(request, timeout) + except ConnectionResetError: # pragma: nocover + # We're currently testing this case in HTTP/2. + # Really we should test it here too, but this'll do in the meantime. + raise NotConnected() from None + task, args = self._send_request_data, [request.stream(), timeout] async with self.backend.background_manager(task, args=args): http_version, status_code, headers = await self._receive_response(timeout) diff --git a/http3/dispatch/http2.py b/http3/dispatch/http2.py index 56c7728f..3dd778d5 100644 --- a/http3/dispatch/http2.py +++ b/http3/dispatch/http2.py @@ -6,7 +6,7 @@ import h2.events from ..concurrency import TimeoutFlag from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes -from ..exceptions import ConnectTimeout, ReadTimeout +from ..exceptions import ConnectTimeout, NotConnected, ReadTimeout from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse @@ -39,7 +39,11 @@ class HTTP2Connection: if not self.initialized: self.initiate_connection() - stream_id = await self.send_headers(request, timeout) + try: + stream_id = await self.send_headers(request, timeout) + except ConnectionResetError: + raise NotConnected() from None + self.events[stream_id] = [] self.timeout_flags[stream_id] = TimeoutFlag() diff --git a/http3/exceptions.py b/http3/exceptions.py index ed311730..3305e2d7 100644 --- a/http3/exceptions.py +++ b/http3/exceptions.py @@ -34,9 +34,16 @@ class PoolTimeout(Timeout): # HTTP exceptions... +class NotConnected(Exception): + """ + A connection was lost at the point of starting a request, + prior to any writes succeeding. + """ + + class HttpError(Exception): """ - An Http error occurred. + An HTTP error occurred. """ diff --git a/setup.py b/setup.py index 69fcde0a..a6a33d0d 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ def get_version(package): """ Return package version as listed in `__version__` in `init.py`. """ - with open(os.path.join(package, "__init__.py")) as f: + with open(os.path.join(package, "__version__.py")) as f: return re.search("__version__ = ['\"]([^'\"]+)['\"]", f.read()).group(1) diff --git a/tests/dispatch/test_http2.py b/tests/dispatch/test_http2.py index b87e528a..8da5b0d5 100644 --- a/tests/dispatch/test_http2.py +++ b/tests/dispatch/test_http2.py @@ -59,3 +59,22 @@ def test_http2_multiple_requests(): assert response_3.status_code == 200 assert json.loads(response_3.content) == {"method": "GET", "path": "/3", "body": ""} + + +def test_http2_reconnect(): + """ + If a connection has been dropped between requests, then we should + be seemlessly reconnected. + """ + backend = MockHTTP2Backend(app=app) + + with Client(backend=backend) as client: + response_1 = client.get("http://example.org/1") + backend.server.raise_disconnect = True + response_2 = client.get("http://example.org/2") + + assert response_1.status_code == 200 + assert json.loads(response_1.content) == {"method": "GET", "path": "/1", "body": ""} + + assert response_2.status_code == 200 + assert json.loads(response_2.content) == {"method": "GET", "path": "/2", "body": ""} diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index 4764f318..85c2674b 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -19,6 +19,7 @@ from http3 import ( class MockHTTP2Backend(AsyncioBackend): def __init__(self, app): self.app = app + self.server = None async def connect( self, @@ -27,8 +28,8 @@ class MockHTTP2Backend(AsyncioBackend): ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]: - server = MockHTTP2Server(self.app) - return (server, server, Protocol.HTTP_2) + self.server = MockHTTP2Server(self.app) + return (self.server, self.server, Protocol.HTTP_2) class MockHTTP2Server(BaseReader, BaseWriter): @@ -42,6 +43,7 @@ class MockHTTP2Server(BaseReader, BaseWriter): self.app = app self.buffer = b"" self.requests = {} + self.raise_disconnect = False # BaseReader interface @@ -53,6 +55,9 @@ class MockHTTP2Server(BaseReader, BaseWriter): # BaseWriter interface def write_no_block(self, data: bytes) -> None: + if self.raise_disconnect: + self.raise_disconnect = False + raise ConnectionResetError() events = self.conn.receive_data(data) self.buffer += self.conn.data_to_send() for event in events: