return data
+ def is_connection_dropped(self) -> bool:
+ return self.stream_reader.at_eof()
+
class Writer(BaseWriter):
def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
else:
assert self.h11_connection is not None
return self.h11_connection.is_closed
+
+ def is_connection_dropped(self) -> bool:
+ if self.h2_connection is not None:
+ return self.h2_connection.is_connection_dropped()
+ else:
+ assert self.h11_connection is not None
+ return self.h11_connection.is_connection_dropped()
TimeoutTypes,
VerifyTypes,
)
-from ..exceptions import NotConnected
from ..interfaces import AsyncDispatcher, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse, Origin
from .connection import HTTPConnection
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> AsyncResponse:
- allow_connection_reuse = True
- connection = None
- while connection is None:
- connection = await self.acquire_connection(
- origin=request.url.origin, allow_connection_reuse=allow_connection_reuse
+ connection = await self.acquire_connection(origin=request.url.origin)
+ try:
+ response = await connection.send(
+ request, verify=verify, cert=cert, timeout=timeout
)
- try:
- response = await connection.send(
- request, verify=verify, cert=cert, timeout=timeout
- )
- except BaseException as exc:
- self.active_connections.remove(connection)
- self.max_connections.release()
- if isinstance(exc, NotConnected) and allow_connection_reuse:
- connection = None
- allow_connection_reuse = False
- else:
- raise exc
+ except BaseException as exc:
+ self.active_connections.remove(connection)
+ self.max_connections.release()
+ raise exc
return response
- async def acquire_connection(
- self, origin: Origin, allow_connection_reuse: bool = True
- ) -> HTTPConnection:
- connection = None
- if allow_connection_reuse:
- connection = self.active_connections.pop_by_origin(origin, http2_only=True)
- if connection is None:
- connection = self.keepalive_connections.pop_by_origin(origin)
+ async def acquire_connection(self, origin: Origin) -> HTTPConnection:
+ connection = self.active_connections.pop_by_origin(origin, http2_only=True)
+ if connection is None:
+ connection = self.keepalive_connections.pop_by_origin(origin)
+
+ if connection is not None and connection.is_connection_dropped():
+ connection = None
if connection is None:
await self.max_connections.acquire()
from ..concurrency import TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
-from ..exceptions import NotConnected
from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse
) -> AsyncResponse:
timeout = None if timeout is None else TimeoutConfig(timeout)
- try:
- await self._send_request(request, timeout)
- except ConnectionResetError: # pragma: nocover
- # We're currently testing this case in HTTP/2.
- # Really we should test it here too, but this'll do in the meantime.
- raise NotConnected() from None
+ await self._send_request(request, timeout)
task, args = self._send_request_data, [request.stream(), timeout]
async with self.backend.background_manager(task, args=args):
@property
def is_closed(self) -> bool:
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
+
+ def is_connection_dropped(self) -> bool:
+ return self.reader.is_connection_dropped()
from ..concurrency import TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
-from ..exceptions import NotConnected
from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse
if not self.initialized:
self.initiate_connection()
- try:
- stream_id = await self.send_headers(request, timeout)
- except ConnectionResetError:
- raise NotConnected() from None
+ stream_id = await self.send_headers(request, timeout)
self.events[stream_id] = []
self.timeout_flags[stream_id] = TimeoutFlag()
@property
def is_closed(self) -> bool:
return False
+
+ def is_connection_dropped(self) -> bool:
+ return self.reader.is_connection_dropped()
# HTTP exceptions...
-class NotConnected(Exception):
- """
- A connection was lost at the point of starting a request,
- prior to any writes succeeding.
- """
-
-
class HttpError(Exception):
"""
An HTTP error occurred.
) -> bytes:
raise NotImplementedError() # pragma: no cover
+ def is_connection_dropped(self) -> bool:
+ raise NotImplementedError() # pragma: no cover
+
class BaseWriter:
"""
await response.close()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0
+
+
+@pytest.mark.asyncio
+async def test_keepalive_connection_closed_by_server_is_reestablished(server):
+ """
+ Upon keep-alive connection closed by remote a new connection should be reestablished.
+ """
+ async with httpx.ConnectionPool() as http:
+ response = await http.request("GET", "http://127.0.0.1:8000/")
+ await response.read()
+
+ await server.shutdown() # shutdown the server to close the keep-alive connection
+ await server.startup()
+
+ response = await http.request("GET", "http://127.0.0.1:8000/")
+ await response.read()
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 1
+
+
+@pytest.mark.asyncio
+async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server):
+ """
+ Upon keep-alive connection closed by remote a new connection should be reestablished.
+ """
+ async with httpx.ConnectionPool() as http:
+ response = await http.request("GET", "http://127.0.0.1:8000/")
+ await response.read()
+
+ await server.shutdown() # shutdown the server to close the keep-alive connection
+ await server.startup()
+
+ response = await http.request("GET", "http://127.0.0.1:8000/")
+ await response.read()
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 1
with Client(backend=backend) as client:
response_1 = client.get("http://example.org/1")
- backend.server.raise_disconnect = True
+ backend.server.close_connection = True
response_2 = client.get("http://example.org/2")
assert response_1.status_code == 200
self.app = app
self.buffer = b""
self.requests = {}
- self.raise_disconnect = False
+ self.close_connection = False
# BaseReader interface
# BaseWriter interface
def write_no_block(self, data: bytes) -> None:
- if self.raise_disconnect:
- self.raise_disconnect = False
- raise ConnectionResetError()
events = self.conn.receive_data(data)
self.buffer += self.conn.data_to_send()
for event in events:
async def close(self) -> None:
pass
+ def is_connection_dropped(self) -> bool:
+ return self.close_connection
+
# Server implementation
def request_received(self, headers, stream_id):
CertTypes,
Client,
Dispatcher,
- multipart,
Request,
Response,
TimeoutTypes,
VerifyTypes,
+ multipart,
)