From: Tom Christie Date: Thu, 12 Dec 2019 11:52:49 +0000 (+0000) Subject: Keep-alive timeouts. (#627) X-Git-Tag: 0.9.4~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=499de51f2b4772bff5dd3955474d349a0e9324e4;p=thirdparty%2Fhttpx.git Keep-alive timeouts. (#627) * Add .time() to backend * Add connection timeouts * Add test case for keep alive timeouts * Update httpx/dispatch/connection_pool.py Co-Authored-By: Florimond Manca * Cleanups from review * Use .expires_at, rather than .timeout_at --- diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index c73a1382..bb795b1d 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -242,6 +242,10 @@ class AsyncioBackend(ConcurrencyBackend): return SocketStream(stream_reader=stream_reader, stream_writer=stream_writer) + def time(self) -> float: + loop = asyncio.get_event_loop() + return loop.time() + async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: diff --git a/httpx/concurrency/auto.py b/httpx/concurrency/auto.py index c90ce241..c11c0646 100644 --- a/httpx/concurrency/auto.py +++ b/httpx/concurrency/auto.py @@ -41,6 +41,9 @@ class AutoBackend(ConcurrencyBackend): ) -> BaseSocketStream: return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout) + def time(self) -> float: + return self.backend.time() + def get_semaphore(self, max_value: int) -> BasePoolSemaphore: return self.backend.get_semaphore(max_value) diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index da6dd986..366e5749 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -99,6 +99,9 @@ class ConcurrencyBackend: ) -> BaseSocketStream: raise NotImplementedError() # pragma: no cover + def time(self) -> float: + raise NotImplementedError() # pragma: no cover + def get_semaphore(self, max_value: int) -> BasePoolSemaphore: raise NotImplementedError() # pragma: no cover diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index 6f2ab2d1..0c64988d 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -156,6 +156,9 @@ class TrioBackend(ConcurrencyBackend): functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args ) + def time(self) -> float: + return trio.current_time() + def get_semaphore(self, max_value: int) -> BasePoolSemaphore: return PoolSemaphore(max_value) diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index 009895af..15ed455b 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -38,6 +38,7 @@ class HTTPConnection(Dispatcher): self.release_func = release_func self.uds = uds self.open_connection: typing.Optional[OpenConnection] = None + self.expires_at: typing.Optional[float] = None async def send( self, diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py index 4fe39988..db23640b 100644 --- a/httpx/dispatch/connection_pool.py +++ b/httpx/dispatch/connection_pool.py @@ -78,6 +78,8 @@ class ConnectionStore: class ConnectionPool(Dispatcher): + KEEP_ALIVE_EXPIRY = 5.0 + def __init__( self, *, @@ -101,6 +103,7 @@ class ConnectionPool(Dispatcher): self.active_connections = ConnectionStore() self.backend = lookup_backend(backend) + self.next_keepalive_check = 0.0 @property def max_connections(self) -> BasePoolSemaphore: @@ -118,6 +121,21 @@ class ConnectionPool(Dispatcher): def num_connections(self) -> int: return len(self.keepalive_connections) + len(self.active_connections) + async def check_keepalive_expiry(self) -> None: + now = self.backend.time() + if now < self.next_keepalive_check: + return + self.next_keepalive_check = now + 1.0 + + # Iterate through all the keep alive connections. + # We create a list here to avoid any 'changed during iteration' errors. + keepalives = list(self.keepalive_connections.all.keys()) + for connection in keepalives: + if connection.expires_at is not None and now > connection.expires_at: + self.keepalive_connections.remove(connection) + self.max_connections.release() + await connection.close() + async def send( self, request: Request, @@ -125,6 +143,7 @@ class ConnectionPool(Dispatcher): cert: CertTypes = None, timeout: Timeout = None, ) -> Response: + await self.check_keepalive_expiry() connection = await self.acquire_connection( origin=request.url.origin, timeout=timeout ) @@ -180,6 +199,8 @@ class ConnectionPool(Dispatcher): self.max_connections.release() await connection.close() else: + now = self.backend.time() + connection.expires_at = now + self.KEEP_ALIVE_EXPIRY self.active_connections.remove(connection) self.keepalive_connections.add(connection) diff --git a/tests/dispatch/test_connection_pools.py b/tests/dispatch/test_connection_pools.py index f10b16ab..447b8d66 100644 --- a/tests/dispatch/test_connection_pools.py +++ b/tests/dispatch/test_connection_pools.py @@ -18,6 +18,37 @@ async def test_keepalive_connections(server, backend): assert len(http.keepalive_connections) == 1 +async def test_keepalive_timeout(server, backend): + """ + Keep-alive connections should timeout. + """ + async with ConnectionPool() as http: + response = await http.request("GET", server.url) + await response.read() + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 + + http.next_keepalive_check = 0.0 + await http.check_keepalive_expiry() + + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 + + async with ConnectionPool() as http: + http.KEEP_ALIVE_EXPIRY = 0.0 + + response = await http.request("GET", server.url) + await response.read() + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 + + http.next_keepalive_check = 0.0 + await http.check_keepalive_expiry() + + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 0 + + async def test_differing_connection_keys(server, backend): """ Connections to differing connection keys should result in multiple connections.