]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Check disconnections on connection reacquiry (#145)
authorTom Christie <tom@tomchristie.com>
Thu, 25 Jul 2019 21:52:41 +0000 (22:52 +0100)
committerGitHub <noreply@github.com>
Thu, 25 Jul 2019 21:52:41 +0000 (22:52 +0100)
* Detect EOF signaling remote server closed connection

Raise ConnectionClosedByRemote and handle on `send`

* Fix linting

* Use existing NotConnected exception

* Add `Reader.is_connection_dropped` method

* Check connection before sending h11 events as well

* Add test covering connection lost before reading response content

* Check for connection closed on acquiring it from the pool

* Clean up ConnectionPool logic around reaquiry of connections

httpx/concurrency.py
httpx/dispatch/connection.py
httpx/dispatch/connection_pool.py
httpx/dispatch/http11.py
httpx/dispatch/http2.py
httpx/exceptions.py
httpx/interfaces.py
tests/dispatch/test_connection_pools.py
tests/dispatch/test_http2.py
tests/dispatch/utils.py
tests/test_multipart.py

index b8246b535036639275ae79ca2dd3b4384960d610..2ee45a85036ed20eac29e8f72def77045c99f3f1 100644 (file)
@@ -108,6 +108,9 @@ class Reader(BaseReader):
 
         return data
 
+    def is_connection_dropped(self) -> bool:
+        return self.stream_reader.at_eof()
+
 
 class Writer(BaseWriter):
     def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
index b1400afdf4a7640beaf9b3c5dec0f6c1d1c3ef26..4a303f27793b7197dc4fa25ba942c890140a108e 100644 (file)
@@ -102,3 +102,10 @@ class HTTPConnection(AsyncDispatcher):
         else:
             assert self.h11_connection is not None
             return self.h11_connection.is_closed
+
+    def is_connection_dropped(self) -> bool:
+        if self.h2_connection is not None:
+            return self.h2_connection.is_connection_dropped()
+        else:
+            assert self.h11_connection is not None
+            return self.h11_connection.is_connection_dropped()
index 0b827c128f0cedbddcdbe49f8124fe8d5ae6fca5..b6c0457c5f1dc0f3a9f502eb4ee5cbd72a6078ec 100644 (file)
@@ -9,7 +9,6 @@ from ..config import (
     TimeoutTypes,
     VerifyTypes,
 )
-from ..exceptions import NotConnected
 from ..interfaces import AsyncDispatcher, ConcurrencyBackend
 from ..models import AsyncRequest, AsyncResponse, Origin
 from .connection import HTTPConnection
@@ -108,35 +107,25 @@ class ConnectionPool(AsyncDispatcher):
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
     ) -> AsyncResponse:
-        allow_connection_reuse = True
-        connection = None
-        while connection is None:
-            connection = await self.acquire_connection(
-                origin=request.url.origin, allow_connection_reuse=allow_connection_reuse
+        connection = await self.acquire_connection(origin=request.url.origin)
+        try:
+            response = await connection.send(
+                request, verify=verify, cert=cert, timeout=timeout
             )
-            try:
-                response = await connection.send(
-                    request, verify=verify, cert=cert, timeout=timeout
-                )
-            except BaseException as exc:
-                self.active_connections.remove(connection)
-                self.max_connections.release()
-                if isinstance(exc, NotConnected) and allow_connection_reuse:
-                    connection = None
-                    allow_connection_reuse = False
-                else:
-                    raise exc
+        except BaseException as exc:
+            self.active_connections.remove(connection)
+            self.max_connections.release()
+            raise exc
 
         return response
 
-    async def acquire_connection(
-        self, origin: Origin, allow_connection_reuse: bool = True
-    ) -> HTTPConnection:
-        connection = None
-        if allow_connection_reuse:
-            connection = self.active_connections.pop_by_origin(origin, http2_only=True)
-            if connection is None:
-                connection = self.keepalive_connections.pop_by_origin(origin)
+    async def acquire_connection(self, origin: Origin) -> HTTPConnection:
+        connection = self.active_connections.pop_by_origin(origin, http2_only=True)
+        if connection is None:
+            connection = self.keepalive_connections.pop_by_origin(origin)
+
+        if connection is not None and connection.is_connection_dropped():
+            connection = None
 
         if connection is None:
             await self.max_connections.acquire()
index c23953562fbc981b6f2cee3c1f4ea0bd9793d09e..dd0e24ad105cad88e1df767b3efe9595b957da3d 100644 (file)
@@ -4,7 +4,6 @@ import h11
 
 from ..concurrency import TimeoutFlag
 from ..config import TimeoutConfig, TimeoutTypes
-from ..exceptions import NotConnected
 from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
 from ..models import AsyncRequest, AsyncResponse
 
@@ -46,12 +45,7 @@ class HTTP11Connection:
     ) -> AsyncResponse:
         timeout = None if timeout is None else TimeoutConfig(timeout)
 
-        try:
-            await self._send_request(request, timeout)
-        except ConnectionResetError:  # pragma: nocover
-            # We're currently testing this case in HTTP/2.
-            # Really we should test it here too, but this'll do in the meantime.
-            raise NotConnected() from None
+        await self._send_request(request, timeout)
 
         task, args = self._send_request_data, [request.stream(), timeout]
         async with self.backend.background_manager(task, args=args):
@@ -188,3 +182,6 @@ class HTTP11Connection:
     @property
     def is_closed(self) -> bool:
         return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
+
+    def is_connection_dropped(self) -> bool:
+        return self.reader.is_connection_dropped()
index 35d487ad6b056dd16d45a883639bc2ef63cc327d..331f82df3808d5fceaaf0bac02f206dd9563f514 100644 (file)
@@ -6,7 +6,6 @@ import h2.events
 
 from ..concurrency import TimeoutFlag
 from ..config import TimeoutConfig, TimeoutTypes
-from ..exceptions import NotConnected
 from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
 from ..models import AsyncRequest, AsyncResponse
 
@@ -39,10 +38,7 @@ class HTTP2Connection:
         if not self.initialized:
             self.initiate_connection()
 
-        try:
-            stream_id = await self.send_headers(request, timeout)
-        except ConnectionResetError:
-            raise NotConnected() from None
+        stream_id = await self.send_headers(request, timeout)
 
         self.events[stream_id] = []
         self.timeout_flags[stream_id] = TimeoutFlag()
@@ -176,3 +172,6 @@ class HTTP2Connection:
     @property
     def is_closed(self) -> bool:
         return False
+
+    def is_connection_dropped(self) -> bool:
+        return self.reader.is_connection_dropped()
index 3305e2d7f7d7b47974a314c06d509097663fadde..19af3e6be46d0a6e807d95aefd1a8b26ba80c747 100644 (file)
@@ -34,13 +34,6 @@ class PoolTimeout(Timeout):
 # HTTP exceptions...
 
 
-class NotConnected(Exception):
-    """
-    A connection was lost at the point of starting a request,
-    prior to any writes succeeding.
-    """
-
-
 class HttpError(Exception):
     """
     An HTTP error occurred.
index 02d11ce5b56d7b485983cfd665c95d6c0d4a1378..f058edeb6dcb48c0228aee8b06fb928f1f7ccdf6 100644 (file)
@@ -130,6 +130,9 @@ class BaseReader:
     ) -> bytes:
         raise NotImplementedError()  # pragma: no cover
 
+    def is_connection_dropped(self) -> bool:
+        raise NotImplementedError()  # pragma: no cover
+
 
 class BaseWriter:
     """
index 5ef884ee58765334dddffec0f1a7ca9e3ef43fc7..14fd40c62e8fe57299a19421ebe69ef19446580e 100644 (file)
@@ -131,3 +131,39 @@ async def test_premature_response_close(server):
         await response.close()
         assert len(http.active_connections) == 0
         assert len(http.keepalive_connections) == 0
+
+
+@pytest.mark.asyncio
+async def test_keepalive_connection_closed_by_server_is_reestablished(server):
+    """
+    Upon keep-alive connection closed by remote a new connection should be reestablished.
+    """
+    async with httpx.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/")
+        await response.read()
+
+        await server.shutdown()  # shutdown the server to close the keep-alive connection
+        await server.startup()
+
+        response = await http.request("GET", "http://127.0.0.1:8000/")
+        await response.read()
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 1
+
+
+@pytest.mark.asyncio
+async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server):
+    """
+    Upon keep-alive connection closed by remote a new connection should be reestablished.
+    """
+    async with httpx.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/")
+        await response.read()
+
+        await server.shutdown()  # shutdown the server to close the keep-alive connection
+        await server.startup()
+
+        response = await http.request("GET", "http://127.0.0.1:8000/")
+        await response.read()
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 1
index b19452427dda91ce6bd44cf861974f0fe4fa9a62..99caab6c8e6b6f5e6dcf9e72dc49348488349eac 100644 (file)
@@ -68,7 +68,7 @@ def test_http2_reconnect():
 
     with Client(backend=backend) as client:
         response_1 = client.get("http://example.org/1")
-        backend.server.raise_disconnect = True
+        backend.server.close_connection = True
         response_2 = client.get("http://example.org/2")
 
     assert response_1.status_code == 200
index a1e02a4c18f5fbdd0eddcb8a11db4ad107edfe98..8a62554b2583f672ec3c5c5cfca4c4cd1a809c2d 100644 (file)
@@ -43,7 +43,7 @@ class MockHTTP2Server(BaseReader, BaseWriter):
         self.app = app
         self.buffer = b""
         self.requests = {}
-        self.raise_disconnect = False
+        self.close_connection = False
 
     # BaseReader interface
 
@@ -55,9 +55,6 @@ class MockHTTP2Server(BaseReader, BaseWriter):
     # BaseWriter interface
 
     def write_no_block(self, data: bytes) -> None:
-        if self.raise_disconnect:
-            self.raise_disconnect = False
-            raise ConnectionResetError()
         events = self.conn.receive_data(data)
         self.buffer += self.conn.data_to_send()
         for event in events:
@@ -74,6 +71,9 @@ class MockHTTP2Server(BaseReader, BaseWriter):
     async def close(self) -> None:
         pass
 
+    def is_connection_dropped(self) -> bool:
+        return self.close_connection
+
     # Server implementation
 
     def request_received(self, headers, stream_id):
index 9a8f12c7c22f1c8ada03c2929eeea431af93976e..354bccee7361f9b078c9b01d42730102610ea898 100644 (file)
@@ -10,11 +10,11 @@ from httpx import (
     CertTypes,
     Client,
     Dispatcher,
-    multipart,
     Request,
     Response,
     TimeoutTypes,
     VerifyTypes,
+    multipart,
 )