From: Tom Christie Date: Thu, 25 Jul 2019 21:52:41 +0000 (+0100) Subject: Check disconnections on connection reacquiry (#145) X-Git-Tag: 0.6.8~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ec365c0e8e0c81279dc7a482e038aa5b547261e8;p=thirdparty%2Fhttpx.git Check disconnections on connection reacquiry (#145) * Detect EOF signaling remote server closed connection Raise ConnectionClosedByRemote and handle on `send` * Fix linting * Use existing NotConnected exception * Add `Reader.is_connection_dropped` method * Check connection before sending h11 events as well * Add test covering connection lost before reading response content * Check for connection closed on acquiring it from the pool * Clean up ConnectionPool logic around reaquiry of connections --- diff --git a/httpx/concurrency.py b/httpx/concurrency.py index b8246b53..2ee45a85 100644 --- a/httpx/concurrency.py +++ b/httpx/concurrency.py @@ -108,6 +108,9 @@ class Reader(BaseReader): return data + def is_connection_dropped(self) -> bool: + return self.stream_reader.at_eof() + class Writer(BaseWriter): def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig): diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index b1400afd..4a303f27 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -102,3 +102,10 @@ class HTTPConnection(AsyncDispatcher): else: assert self.h11_connection is not None return self.h11_connection.is_closed + + def is_connection_dropped(self) -> bool: + if self.h2_connection is not None: + return self.h2_connection.is_connection_dropped() + else: + assert self.h11_connection is not None + return self.h11_connection.is_connection_dropped() diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py index 0b827c12..b6c0457c 100644 --- a/httpx/dispatch/connection_pool.py +++ b/httpx/dispatch/connection_pool.py @@ -9,7 +9,6 @@ from ..config import ( TimeoutTypes, VerifyTypes, ) -from ..exceptions import NotConnected from ..interfaces import AsyncDispatcher, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse, Origin from .connection import HTTPConnection @@ -108,35 +107,25 @@ class ConnectionPool(AsyncDispatcher): cert: CertTypes = None, timeout: TimeoutTypes = None, ) -> AsyncResponse: - allow_connection_reuse = True - connection = None - while connection is None: - connection = await self.acquire_connection( - origin=request.url.origin, allow_connection_reuse=allow_connection_reuse + connection = await self.acquire_connection(origin=request.url.origin) + try: + response = await connection.send( + request, verify=verify, cert=cert, timeout=timeout ) - try: - response = await connection.send( - request, verify=verify, cert=cert, timeout=timeout - ) - except BaseException as exc: - self.active_connections.remove(connection) - self.max_connections.release() - if isinstance(exc, NotConnected) and allow_connection_reuse: - connection = None - allow_connection_reuse = False - else: - raise exc + except BaseException as exc: + self.active_connections.remove(connection) + self.max_connections.release() + raise exc return response - async def acquire_connection( - self, origin: Origin, allow_connection_reuse: bool = True - ) -> HTTPConnection: - connection = None - if allow_connection_reuse: - connection = self.active_connections.pop_by_origin(origin, http2_only=True) - if connection is None: - connection = self.keepalive_connections.pop_by_origin(origin) + async def acquire_connection(self, origin: Origin) -> HTTPConnection: + connection = self.active_connections.pop_by_origin(origin, http2_only=True) + if connection is None: + connection = self.keepalive_connections.pop_by_origin(origin) + + if connection is not None and connection.is_connection_dropped(): + connection = None if connection is None: await self.max_connections.acquire() diff --git a/httpx/dispatch/http11.py b/httpx/dispatch/http11.py index c2395356..dd0e24ad 100644 --- a/httpx/dispatch/http11.py +++ b/httpx/dispatch/http11.py @@ -4,7 +4,6 @@ import h11 from ..concurrency import TimeoutFlag from ..config import TimeoutConfig, TimeoutTypes -from ..exceptions import NotConnected from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse @@ -46,12 +45,7 @@ class HTTP11Connection: ) -> AsyncResponse: timeout = None if timeout is None else TimeoutConfig(timeout) - try: - await self._send_request(request, timeout) - except ConnectionResetError: # pragma: nocover - # We're currently testing this case in HTTP/2. - # Really we should test it here too, but this'll do in the meantime. - raise NotConnected() from None + await self._send_request(request, timeout) task, args = self._send_request_data, [request.stream(), timeout] async with self.backend.background_manager(task, args=args): @@ -188,3 +182,6 @@ class HTTP11Connection: @property def is_closed(self) -> bool: return self.h11_state.our_state in (h11.CLOSED, h11.ERROR) + + def is_connection_dropped(self) -> bool: + return self.reader.is_connection_dropped() diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index 35d487ad..331f82df 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -6,7 +6,6 @@ import h2.events from ..concurrency import TimeoutFlag from ..config import TimeoutConfig, TimeoutTypes -from ..exceptions import NotConnected from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse @@ -39,10 +38,7 @@ class HTTP2Connection: if not self.initialized: self.initiate_connection() - try: - stream_id = await self.send_headers(request, timeout) - except ConnectionResetError: - raise NotConnected() from None + stream_id = await self.send_headers(request, timeout) self.events[stream_id] = [] self.timeout_flags[stream_id] = TimeoutFlag() @@ -176,3 +172,6 @@ class HTTP2Connection: @property def is_closed(self) -> bool: return False + + def is_connection_dropped(self) -> bool: + return self.reader.is_connection_dropped() diff --git a/httpx/exceptions.py b/httpx/exceptions.py index 3305e2d7..19af3e6b 100644 --- a/httpx/exceptions.py +++ b/httpx/exceptions.py @@ -34,13 +34,6 @@ class PoolTimeout(Timeout): # HTTP exceptions... -class NotConnected(Exception): - """ - A connection was lost at the point of starting a request, - prior to any writes succeeding. - """ - - class HttpError(Exception): """ An HTTP error occurred. diff --git a/httpx/interfaces.py b/httpx/interfaces.py index 02d11ce5..f058edeb 100644 --- a/httpx/interfaces.py +++ b/httpx/interfaces.py @@ -130,6 +130,9 @@ class BaseReader: ) -> bytes: raise NotImplementedError() # pragma: no cover + def is_connection_dropped(self) -> bool: + raise NotImplementedError() # pragma: no cover + class BaseWriter: """ diff --git a/tests/dispatch/test_connection_pools.py b/tests/dispatch/test_connection_pools.py index 5ef884ee..14fd40c6 100644 --- a/tests/dispatch/test_connection_pools.py +++ b/tests/dispatch/test_connection_pools.py @@ -131,3 +131,39 @@ async def test_premature_response_close(server): await response.close() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 0 + + +@pytest.mark.asyncio +async def test_keepalive_connection_closed_by_server_is_reestablished(server): + """ + Upon keep-alive connection closed by remote a new connection should be reestablished. + """ + async with httpx.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() + + await server.shutdown() # shutdown the server to close the keep-alive connection + await server.startup() + + response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 + + +@pytest.mark.asyncio +async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server): + """ + Upon keep-alive connection closed by remote a new connection should be reestablished. + """ + async with httpx.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() + + await server.shutdown() # shutdown the server to close the keep-alive connection + await server.startup() + + response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 diff --git a/tests/dispatch/test_http2.py b/tests/dispatch/test_http2.py index b1945242..99caab6c 100644 --- a/tests/dispatch/test_http2.py +++ b/tests/dispatch/test_http2.py @@ -68,7 +68,7 @@ def test_http2_reconnect(): with Client(backend=backend) as client: response_1 = client.get("http://example.org/1") - backend.server.raise_disconnect = True + backend.server.close_connection = True response_2 = client.get("http://example.org/2") assert response_1.status_code == 200 diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index a1e02a4c..8a62554b 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -43,7 +43,7 @@ class MockHTTP2Server(BaseReader, BaseWriter): self.app = app self.buffer = b"" self.requests = {} - self.raise_disconnect = False + self.close_connection = False # BaseReader interface @@ -55,9 +55,6 @@ class MockHTTP2Server(BaseReader, BaseWriter): # BaseWriter interface def write_no_block(self, data: bytes) -> None: - if self.raise_disconnect: - self.raise_disconnect = False - raise ConnectionResetError() events = self.conn.receive_data(data) self.buffer += self.conn.data_to_send() for event in events: @@ -74,6 +71,9 @@ class MockHTTP2Server(BaseReader, BaseWriter): async def close(self) -> None: pass + def is_connection_dropped(self) -> bool: + return self.close_connection + # Server implementation def request_received(self, headers, stream_id): diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 9a8f12c7..354bccee 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -10,11 +10,11 @@ from httpx import ( CertTypes, Client, Dispatcher, - multipart, Request, Response, TimeoutTypes, VerifyTypes, + multipart, )