]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Check for connection aliveness on pool re-acquiry (#111)
authorTom Christie <tom@tomchristie.com>
Mon, 8 Jul 2019 14:57:29 +0000 (15:57 +0100)
committerGitHub <noreply@github.com>
Mon, 8 Jul 2019 14:57:29 +0000 (15:57 +0100)
* Check for connection aliveness on pool re-acquiry

* Test for connection re-acquiry with HTTP/2

* nocover on HTTP/1.1 ConnectionResetError, since we're testing the equivelent on HTTP/2

* Fix setup.py to read version from __version__

http3/dispatch/connection_pool.py
http3/dispatch/http11.py
http3/dispatch/http2.py
http3/exceptions.py
setup.py
tests/dispatch/test_http2.py
tests/dispatch/utils.py

index c84117ca45bb6652437a7a7d9c083b5b0362e2e1..d9794db49bbefda140bef4be7dfaccfcbde2b16f 100644 (file)
@@ -11,7 +11,7 @@ from ..config import (
     VerifyTypes,
 )
 from ..decoders import ACCEPT_ENCODING
-from ..exceptions import PoolTimeout
+from ..exceptions import NotConnected, PoolTimeout
 from ..interfaces import AsyncDispatcher, ConcurrencyBackend
 from ..models import AsyncRequest, AsyncResponse, Origin
 from .connection import HTTPConnection
@@ -110,21 +110,35 @@ class ConnectionPool(AsyncDispatcher):
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
     ) -> AsyncResponse:
-        connection = await self.acquire_connection(request.url.origin)
-        try:
-            response = await connection.send(
-                request, verify=verify, cert=cert, timeout=timeout
+        allow_connection_reuse = True
+        connection = None
+        while connection is None:
+            connection = await self.acquire_connection(
+                origin=request.url.origin, allow_connection_reuse=allow_connection_reuse
             )
-        except BaseException as exc:
-            self.active_connections.remove(connection)
-            self.max_connections.release()
-            raise exc
+            try:
+                response = await connection.send(
+                    request, verify=verify, cert=cert, timeout=timeout
+                )
+            except BaseException as exc:
+                self.active_connections.remove(connection)
+                self.max_connections.release()
+                if isinstance(exc, NotConnected) and allow_connection_reuse:
+                    connection = None
+                    allow_connection_reuse = False
+                else:
+                    raise exc
+
         return response
 
-    async def acquire_connection(self, origin: Origin) -> HTTPConnection:
-        connection = self.active_connections.pop_by_origin(origin, http2_only=True)
-        if connection is None:
-            connection = self.keepalive_connections.pop_by_origin(origin)
+    async def acquire_connection(
+        self, origin: Origin, allow_connection_reuse: bool = True
+    ) -> HTTPConnection:
+        connection = None
+        if allow_connection_reuse:
+            connection = self.active_connections.pop_by_origin(origin, http2_only=True)
+            if connection is None:
+                connection = self.keepalive_connections.pop_by_origin(origin)
 
         if connection is None:
             await self.max_connections.acquire()
index b15e15884ddd2c77e59ea71a98941e904c7e7cad..e4124412a6647f797df102a9fdeb651ab017a3fc 100644 (file)
@@ -4,7 +4,7 @@ import h11
 
 from ..concurrency import TimeoutFlag
 from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
-from ..exceptions import ConnectTimeout, ReadTimeout
+from ..exceptions import ConnectTimeout, NotConnected, ReadTimeout
 from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
 from ..models import AsyncRequest, AsyncResponse
 
@@ -46,7 +46,13 @@ class HTTP11Connection:
     ) -> AsyncResponse:
         timeout = None if timeout is None else TimeoutConfig(timeout)
 
-        await self._send_request(request, timeout)
+        try:
+            await self._send_request(request, timeout)
+        except ConnectionResetError:  # pragma: nocover
+            # We're currently testing this case in HTTP/2.
+            # Really we should test it here too, but this'll do in the meantime.
+            raise NotConnected() from None
+
         task, args = self._send_request_data, [request.stream(), timeout]
         async with self.backend.background_manager(task, args=args):
             http_version, status_code, headers = await self._receive_response(timeout)
index 56c7728f63e3a41c7a9d4cdb25f4f02b532bdadc..3dd778d5a54d9b00a5fd7ba71218082e1838977c 100644 (file)
@@ -6,7 +6,7 @@ import h2.events
 
 from ..concurrency import TimeoutFlag
 from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
-from ..exceptions import ConnectTimeout, ReadTimeout
+from ..exceptions import ConnectTimeout, NotConnected, ReadTimeout
 from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
 from ..models import AsyncRequest, AsyncResponse
 
@@ -39,7 +39,11 @@ class HTTP2Connection:
         if not self.initialized:
             self.initiate_connection()
 
-        stream_id = await self.send_headers(request, timeout)
+        try:
+            stream_id = await self.send_headers(request, timeout)
+        except ConnectionResetError:
+            raise NotConnected() from None
+
         self.events[stream_id] = []
         self.timeout_flags[stream_id] = TimeoutFlag()
 
index ed3117309f5a65e16dfeb552af50ef538e6b3c83..3305e2d7f7d7b47974a314c06d509097663fadde 100644 (file)
@@ -34,9 +34,16 @@ class PoolTimeout(Timeout):
 # HTTP exceptions...
 
 
+class NotConnected(Exception):
+    """
+    A connection was lost at the point of starting a request,
+    prior to any writes succeeding.
+    """
+
+
 class HttpError(Exception):
     """
-    An Http error occurred.
+    An HTTP error occurred.
     """
 
 
index 69fcde0a6ff71ba5801489ace87bbfc308cdd54b..a6a33d0dc5ecc16e0f27404225b74d78049a1d4d 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@ def get_version(package):
     """
     Return package version as listed in `__version__` in `init.py`.
     """
-    with open(os.path.join(package, "__init__.py")) as f:
+    with open(os.path.join(package, "__version__.py")) as f:
         return re.search("__version__ = ['\"]([^'\"]+)['\"]", f.read()).group(1)
 
 
index b87e528a3a5a370d6415341d3037d25c3823b2d3..8da5b0d5cdbe24f1d8b6a221092f1d0968786bfe 100644 (file)
@@ -59,3 +59,22 @@ def test_http2_multiple_requests():
 
     assert response_3.status_code == 200
     assert json.loads(response_3.content) == {"method": "GET", "path": "/3", "body": ""}
+
+
+def test_http2_reconnect():
+    """
+    If a connection has been dropped between requests, then we should
+    be seemlessly reconnected.
+    """
+    backend = MockHTTP2Backend(app=app)
+
+    with Client(backend=backend) as client:
+        response_1 = client.get("http://example.org/1")
+        backend.server.raise_disconnect = True
+        response_2 = client.get("http://example.org/2")
+
+    assert response_1.status_code == 200
+    assert json.loads(response_1.content) == {"method": "GET", "path": "/1", "body": ""}
+
+    assert response_2.status_code == 200
+    assert json.loads(response_2.content) == {"method": "GET", "path": "/2", "body": ""}
index 4764f3186a2bc73f3d0373c7a02b7d08d364945c..85c2674bf6773f14be4469170dcbf89073ec010f 100644 (file)
@@ -19,6 +19,7 @@ from http3 import (
 class MockHTTP2Backend(AsyncioBackend):
     def __init__(self, app):
         self.app = app
+        self.server = None
 
     async def connect(
         self,
@@ -27,8 +28,8 @@ class MockHTTP2Backend(AsyncioBackend):
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: TimeoutConfig,
     ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
-        server = MockHTTP2Server(self.app)
-        return (server, server, Protocol.HTTP_2)
+        self.server = MockHTTP2Server(self.app)
+        return (self.server, self.server, Protocol.HTTP_2)
 
 
 class MockHTTP2Server(BaseReader, BaseWriter):
@@ -42,6 +43,7 @@ class MockHTTP2Server(BaseReader, BaseWriter):
         self.app = app
         self.buffer = b""
         self.requests = {}
+        self.raise_disconnect = False
 
     # BaseReader interface
 
@@ -53,6 +55,9 @@ class MockHTTP2Server(BaseReader, BaseWriter):
     # BaseWriter interface
 
     def write_no_block(self, data: bytes) -> None:
+        if self.raise_disconnect:
+            self.raise_disconnect = False
+            raise ConnectionResetError()
         events = self.conn.receive_data(data)
         self.buffer += self.conn.data_to_send()
         for event in events: