VerifyTypes,
)
from ..decoders import ACCEPT_ENCODING
-from ..exceptions import PoolTimeout
+from ..exceptions import NotConnected, PoolTimeout
from ..interfaces import AsyncDispatcher, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse, Origin
from .connection import HTTPConnection
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> AsyncResponse:
- connection = await self.acquire_connection(request.url.origin)
- try:
- response = await connection.send(
- request, verify=verify, cert=cert, timeout=timeout
+ allow_connection_reuse = True
+ connection = None
+ while connection is None:
+ connection = await self.acquire_connection(
+ origin=request.url.origin, allow_connection_reuse=allow_connection_reuse
)
- except BaseException as exc:
- self.active_connections.remove(connection)
- self.max_connections.release()
- raise exc
+ try:
+ response = await connection.send(
+ request, verify=verify, cert=cert, timeout=timeout
+ )
+ except BaseException as exc:
+ self.active_connections.remove(connection)
+ self.max_connections.release()
+ if isinstance(exc, NotConnected) and allow_connection_reuse:
+ connection = None
+ allow_connection_reuse = False
+ else:
+ raise exc
+
return response
- async def acquire_connection(self, origin: Origin) -> HTTPConnection:
- connection = self.active_connections.pop_by_origin(origin, http2_only=True)
- if connection is None:
- connection = self.keepalive_connections.pop_by_origin(origin)
+ async def acquire_connection(
+ self, origin: Origin, allow_connection_reuse: bool = True
+ ) -> HTTPConnection:
+ connection = None
+ if allow_connection_reuse:
+ connection = self.active_connections.pop_by_origin(origin, http2_only=True)
+ if connection is None:
+ connection = self.keepalive_connections.pop_by_origin(origin)
if connection is None:
await self.max_connections.acquire()
from ..concurrency import TimeoutFlag
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
-from ..exceptions import ConnectTimeout, ReadTimeout
+from ..exceptions import ConnectTimeout, NotConnected, ReadTimeout
from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse
) -> AsyncResponse:
timeout = None if timeout is None else TimeoutConfig(timeout)
- await self._send_request(request, timeout)
+ try:
+ await self._send_request(request, timeout)
+ except ConnectionResetError: # pragma: nocover
+ # We're currently testing this case in HTTP/2.
+ # Really we should test it here too, but this'll do in the meantime.
+ raise NotConnected() from None
+
task, args = self._send_request_data, [request.stream(), timeout]
async with self.backend.background_manager(task, args=args):
http_version, status_code, headers = await self._receive_response(timeout)
from ..concurrency import TimeoutFlag
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
-from ..exceptions import ConnectTimeout, ReadTimeout
+from ..exceptions import ConnectTimeout, NotConnected, ReadTimeout
from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse
if not self.initialized:
self.initiate_connection()
- stream_id = await self.send_headers(request, timeout)
+ try:
+ stream_id = await self.send_headers(request, timeout)
+ except ConnectionResetError:
+ raise NotConnected() from None
+
self.events[stream_id] = []
self.timeout_flags[stream_id] = TimeoutFlag()
# HTTP exceptions...
+class NotConnected(Exception):
+ """
+ A connection was lost at the point of starting a request,
+ prior to any writes succeeding.
+ """
+
+
class HttpError(Exception):
"""
- An Http error occurred.
+ An HTTP error occurred.
"""
"""
Return package version as listed in `__version__` in `init.py`.
"""
- with open(os.path.join(package, "__init__.py")) as f:
+ with open(os.path.join(package, "__version__.py")) as f:
return re.search("__version__ = ['\"]([^'\"]+)['\"]", f.read()).group(1)
assert response_3.status_code == 200
assert json.loads(response_3.content) == {"method": "GET", "path": "/3", "body": ""}
+
+
+def test_http2_reconnect():
+ """
+ If a connection has been dropped between requests, then we should
+ be seemlessly reconnected.
+ """
+ backend = MockHTTP2Backend(app=app)
+
+ with Client(backend=backend) as client:
+ response_1 = client.get("http://example.org/1")
+ backend.server.raise_disconnect = True
+ response_2 = client.get("http://example.org/2")
+
+ assert response_1.status_code == 200
+ assert json.loads(response_1.content) == {"method": "GET", "path": "/1", "body": ""}
+
+ assert response_2.status_code == 200
+ assert json.loads(response_2.content) == {"method": "GET", "path": "/2", "body": ""}
class MockHTTP2Backend(AsyncioBackend):
def __init__(self, app):
self.app = app
+ self.server = None
async def connect(
self,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
- server = MockHTTP2Server(self.app)
- return (server, server, Protocol.HTTP_2)
+ self.server = MockHTTP2Server(self.app)
+ return (self.server, self.server, Protocol.HTTP_2)
class MockHTTP2Server(BaseReader, BaseWriter):
self.app = app
self.buffer = b""
self.requests = {}
+ self.raise_disconnect = False
# BaseReader interface
# BaseWriter interface
def write_no_block(self, data: bytes) -> None:
+ if self.raise_disconnect:
+ self.raise_disconnect = False
+ raise ConnectionResetError()
events = self.conn.receive_data(data)
self.buffer += self.conn.data_to_send()
for event in events: