return SocketStream(stream_reader=stream_reader, stream_writer=stream_writer)
+ def time(self) -> float:
+ loop = asyncio.get_event_loop()
+ return loop.time()
+
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
) -> BaseSocketStream:
return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout)
+ def time(self) -> float:
+ return self.backend.time()
+
def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
return self.backend.get_semaphore(max_value)
) -> BaseSocketStream:
raise NotImplementedError() # pragma: no cover
+ def time(self) -> float:
+ raise NotImplementedError() # pragma: no cover
+
def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover
functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
)
+ def time(self) -> float:
+ return trio.current_time()
+
def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
return PoolSemaphore(max_value)
self.release_func = release_func
self.uds = uds
self.open_connection: typing.Optional[OpenConnection] = None
+ self.expires_at: typing.Optional[float] = None
async def send(
self,
class ConnectionPool(Dispatcher):
+ KEEP_ALIVE_EXPIRY = 5.0
+
def __init__(
self,
*,
self.active_connections = ConnectionStore()
self.backend = lookup_backend(backend)
+ self.next_keepalive_check = 0.0
@property
def max_connections(self) -> BasePoolSemaphore:
def num_connections(self) -> int:
return len(self.keepalive_connections) + len(self.active_connections)
+ async def check_keepalive_expiry(self) -> None:
+ now = self.backend.time()
+ if now < self.next_keepalive_check:
+ return
+ self.next_keepalive_check = now + 1.0
+
+ # Iterate through all the keep alive connections.
+ # We create a list here to avoid any 'changed during iteration' errors.
+ keepalives = list(self.keepalive_connections.all.keys())
+ for connection in keepalives:
+ if connection.expires_at is not None and now > connection.expires_at:
+ self.keepalive_connections.remove(connection)
+ self.max_connections.release()
+ await connection.close()
+
async def send(
self,
request: Request,
cert: CertTypes = None,
timeout: Timeout = None,
) -> Response:
+ await self.check_keepalive_expiry()
connection = await self.acquire_connection(
origin=request.url.origin, timeout=timeout
)
self.max_connections.release()
await connection.close()
else:
+ now = self.backend.time()
+ connection.expires_at = now + self.KEEP_ALIVE_EXPIRY
self.active_connections.remove(connection)
self.keepalive_connections.add(connection)
assert len(http.keepalive_connections) == 1
+async def test_keepalive_timeout(server, backend):
+ """
+ Keep-alive connections should timeout.
+ """
+ async with ConnectionPool() as http:
+ response = await http.request("GET", server.url)
+ await response.read()
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 1
+
+ http.next_keepalive_check = 0.0
+ await http.check_keepalive_expiry()
+
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 1
+
+ async with ConnectionPool() as http:
+ http.KEEP_ALIVE_EXPIRY = 0.0
+
+ response = await http.request("GET", server.url)
+ await response.read()
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 1
+
+ http.next_keepalive_check = 0.0
+ await http.check_keepalive_expiry()
+
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 0
+
+
async def test_differing_connection_keys(server, backend):
"""
Connections to differing connection keys should result in multiple connections.