]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Keep-alive timeouts. (#627)
authorTom Christie <tom@tomchristie.com>
Thu, 12 Dec 2019 11:52:49 +0000 (11:52 +0000)
committerGitHub <noreply@github.com>
Thu, 12 Dec 2019 11:52:49 +0000 (11:52 +0000)
* Add .time() to backend

* Add connection timeouts

* Add test case for keep alive timeouts

* Update httpx/dispatch/connection_pool.py

Co-Authored-By: Florimond Manca <florimond.manca@gmail.com>
* Cleanups from review

* Use .expires_at, rather than .timeout_at

httpx/concurrency/asyncio.py
httpx/concurrency/auto.py
httpx/concurrency/base.py
httpx/concurrency/trio.py
httpx/dispatch/connection.py
httpx/dispatch/connection_pool.py
tests/dispatch/test_connection_pools.py

index c73a13824bb1d7ca9863ea90b4628a14cddd964e..bb795b1d626f0e39594ed40fb86f8c57fcdeb226 100644 (file)
@@ -242,6 +242,10 @@ class AsyncioBackend(ConcurrencyBackend):
 
         return SocketStream(stream_reader=stream_reader, stream_writer=stream_writer)
 
+    def time(self) -> float:
+        loop = asyncio.get_event_loop()
+        return loop.time()
+
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
     ) -> typing.Any:
index c90ce2410a86d727c92a8b99d8afb483b0c4d1dd..c11c06469ecc1c7f41a55895bd938bdc128823b8 100644 (file)
@@ -41,6 +41,9 @@ class AutoBackend(ConcurrencyBackend):
     ) -> BaseSocketStream:
         return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout)
 
+    def time(self) -> float:
+        return self.backend.time()
+
     def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
         return self.backend.get_semaphore(max_value)
 
index da6dd986e046a54c68f2c14b5399946111779d4f..366e5749a0e72e8d2ab2dadb3169efec916fe9f5 100644 (file)
@@ -99,6 +99,9 @@ class ConcurrencyBackend:
     ) -> BaseSocketStream:
         raise NotImplementedError()  # pragma: no cover
 
+    def time(self) -> float:
+        raise NotImplementedError()  # pragma: no cover
+
     def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
         raise NotImplementedError()  # pragma: no cover
 
index 6f2ab2d1890a773b504a5ea75df128dd7d1fac71..0c64988dbee852fe2f76ed862fe19acced54c80d 100644 (file)
@@ -156,6 +156,9 @@ class TrioBackend(ConcurrencyBackend):
             functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
         )
 
+    def time(self) -> float:
+        return trio.current_time()
+
     def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
         return PoolSemaphore(max_value)
 
index 009895af2fde1d7b6036a878388c35f54937e4d0..15ed455b35f93a0633eeaae57cb9917fc9a8a378 100644 (file)
@@ -38,6 +38,7 @@ class HTTPConnection(Dispatcher):
         self.release_func = release_func
         self.uds = uds
         self.open_connection: typing.Optional[OpenConnection] = None
+        self.expires_at: typing.Optional[float] = None
 
     async def send(
         self,
index 4fe3998814d0380b47e2778620d7142219dd56be..db23640b574839bc558053282611436500766529 100644 (file)
@@ -78,6 +78,8 @@ class ConnectionStore:
 
 
 class ConnectionPool(Dispatcher):
+    KEEP_ALIVE_EXPIRY = 5.0
+
     def __init__(
         self,
         *,
@@ -101,6 +103,7 @@ class ConnectionPool(Dispatcher):
         self.active_connections = ConnectionStore()
 
         self.backend = lookup_backend(backend)
+        self.next_keepalive_check = 0.0
 
     @property
     def max_connections(self) -> BasePoolSemaphore:
@@ -118,6 +121,21 @@ class ConnectionPool(Dispatcher):
     def num_connections(self) -> int:
         return len(self.keepalive_connections) + len(self.active_connections)
 
+    async def check_keepalive_expiry(self) -> None:
+        now = self.backend.time()
+        if now < self.next_keepalive_check:
+            return
+        self.next_keepalive_check = now + 1.0
+
+        # Iterate through all the keep alive connections.
+        # We create a list here to avoid any 'changed during iteration' errors.
+        keepalives = list(self.keepalive_connections.all.keys())
+        for connection in keepalives:
+            if connection.expires_at is not None and now > connection.expires_at:
+                self.keepalive_connections.remove(connection)
+                self.max_connections.release()
+                await connection.close()
+
     async def send(
         self,
         request: Request,
@@ -125,6 +143,7 @@ class ConnectionPool(Dispatcher):
         cert: CertTypes = None,
         timeout: Timeout = None,
     ) -> Response:
+        await self.check_keepalive_expiry()
         connection = await self.acquire_connection(
             origin=request.url.origin, timeout=timeout
         )
@@ -180,6 +199,8 @@ class ConnectionPool(Dispatcher):
             self.max_connections.release()
             await connection.close()
         else:
+            now = self.backend.time()
+            connection.expires_at = now + self.KEEP_ALIVE_EXPIRY
             self.active_connections.remove(connection)
             self.keepalive_connections.add(connection)
 
index f10b16ab1451dcc85d21bcf924c97abe2c8ea270..447b8d667f7a780deb09913d84795f84186bd216 100644 (file)
@@ -18,6 +18,37 @@ async def test_keepalive_connections(server, backend):
         assert len(http.keepalive_connections) == 1
 
 
+async def test_keepalive_timeout(server, backend):
+    """
+    Keep-alive connections should timeout.
+    """
+    async with ConnectionPool() as http:
+        response = await http.request("GET", server.url)
+        await response.read()
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 1
+
+        http.next_keepalive_check = 0.0
+        await http.check_keepalive_expiry()
+
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 1
+
+    async with ConnectionPool() as http:
+        http.KEEP_ALIVE_EXPIRY = 0.0
+
+        response = await http.request("GET", server.url)
+        await response.read()
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 1
+
+        http.next_keepalive_check = 0.0
+        await http.check_keepalive_expiry()
+
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 0
+
+
 async def test_differing_connection_keys(server, backend):
     """
     Connections to differing connection keys should result in multiple connections.