From: Tom Christie Date: Thu, 25 Apr 2019 11:57:18 +0000 (+0100) Subject: Add PoolSemaphore X-Git-Tag: 0.3.0~66^2~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=286f04f1a641473d124c5e8403de24b323861054;p=thirdparty%2Fhttpx.git Add PoolSemaphore --- diff --git a/httpcore/connection_pool.py b/httpcore/connection_pool.py index 6ec30289..894212ab 100644 --- a/httpcore/connection_pool.py +++ b/httpcore/connection_pool.py @@ -1,4 +1,3 @@ -import asyncio import typing from .config import ( @@ -13,6 +12,7 @@ from .config import ( from .connection import HTTPConnection from .exceptions import PoolTimeout from .models import Client, Origin, Request, Response +from .streams import PoolSemaphore class ConnectionPool(Client): @@ -32,9 +32,7 @@ class ConnectionPool(Client): self._keepalive_connections = ( {} ) # type: typing.Dict[Origin, typing.List[HTTPConnection]] - self._max_connections = ConnectionSemaphore( - max_connections=self.limits.hard_limit - ) + self._max_connections = PoolSemaphore(limits, timeout) async def send( self, @@ -62,15 +60,7 @@ class ConnectionPool(Client): self.num_active_connections += 1 except (KeyError, IndexError): - if timeout is None: - pool_timeout = self.timeout.pool_timeout - else: - pool_timeout = timeout.pool_timeout - - try: - await asyncio.wait_for(self._max_connections.acquire(), pool_timeout) - except asyncio.TimeoutError: - raise PoolTimeout() + await self._max_connections.acquire(timeout) connection = HTTPConnection( origin, ssl=self.ssl, @@ -108,25 +98,3 @@ class ConnectionPool(Client): self._keepalive_connections.clear() for connection in all_connections: await connection.close() - - -class ConnectionSemaphore: - def __init__(self, max_connections: int = None): - self.max_connections = max_connections - - @property - def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]: - if not hasattr(self, "_semaphore"): - if self.max_connections is None: - self._semaphore = None - else: - self._semaphore = asyncio.BoundedSemaphore(value=self.max_connections) - return self._semaphore - - async def acquire(self) -> None: - if self.semaphore is not None: - await self.semaphore.acquire() - - def release(self) -> None: - if self.semaphore is not None: - self.semaphore.release() diff --git a/httpcore/streams.py b/httpcore/streams.py index 5a9a0abb..cba51fd7 100644 --- a/httpcore/streams.py +++ b/httpcore/streams.py @@ -2,7 +2,9 @@ The `Reader` and `Writer` classes here provide a lightweight layer over `asyncio.StreamReader` and `asyncio.StreamWriter`. -They help encapsulate the timeout logic, make it easier to unit-test +Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`. + +These classes help encapsulate the timeout logic, make it easier to unit-test protocols, and help keep the rest of the package more `async`/`await` based, and less strictly `asyncio`-specific. """ @@ -11,8 +13,8 @@ import enum import ssl import typing -from .config import TimeoutConfig, DEFAULT_TIMEOUT_CONFIG -from .exceptions import ConnectTimeout, ReadTimeout, WriteTimeout +from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig +from .exceptions import ConnectTimeout, ReadTimeout, PoolTimeout, WriteTimeout OptionalTimeout = typing.Optional[TimeoutConfig] @@ -38,6 +40,17 @@ class BaseWriter: raise NotImplementedError() # pragma: no cover +class BasePoolSemaphore: + def __init__(self, limits: PoolLimits, timeout: TimeoutConfig): + raise NotImplementedError() # pragma: no cover + + async def acquire(self, timeout: OptionalTimeout = None) -> None: + raise NotImplementedError() # pragma: no cover + + def release(self) -> None: + raise NotImplementedError() # pragma: no cover + + class Reader(BaseReader): def __init__( self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig @@ -86,6 +99,40 @@ class Writer(BaseWriter): self.stream_writer.close() +class PoolSemaphore(BasePoolSemaphore): + def __init__(self, limits: PoolLimits, timeout: TimeoutConfig): + self.limits = limits + self.timeout = timeout + + @property + def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]: + if not hasattr(self, "_semaphore"): + max_connections = self.limits.hard_limit + if max_connections is None: + self._semaphore = None + else: + self._semaphore = asyncio.BoundedSemaphore(value=max_connections) + return self._semaphore + + async def acquire(self, timeout: OptionalTimeout = None) -> None: + if self.semaphore is None: + return + + if timeout is None: + timeout = self.timeout + + try: + await asyncio.wait_for(self.semaphore.acquire(), timeout.pool_timeout) + except asyncio.TimeoutError: + raise PoolTimeout() + + def release(self) -> None: + if self.semaphore is None: + return + + self.semaphore.release() + + async def connect( hostname: str, port: int, diff --git a/tests/test_http2.py b/tests/test_http2.py index fbdb52d8..f17d0f98 100644 --- a/tests/test_http2.py +++ b/tests/test_http2.py @@ -1,8 +1,9 @@ +import json + import h2.config import h2.connection import h2.events import pytest -import json import httpcore @@ -61,15 +62,15 @@ class MockServer(httpcore.BaseReader, httpcore.BaseWriter): request_headers = dict(request["headers"]) request_data = request["data"] - response_body = json.dumps({ - "method": request_headers[b":method"].decode(), - "path": request_headers[b":path"].decode(), - "body": request_data.decode() - }).encode() + response_body = json.dumps( + { + "method": request_headers[b":method"].decode(), + "path": request_headers[b":path"].decode(), + "body": request_data.decode(), + } + ).encode() - response_headers = ( - (b":status", b"200"), - ) + response_headers = ((b":status", b"200"),) self.conn.send_headers(stream_id, response_headers) self.conn.send_data(stream_id, response_body, end_stream=True) self.buffer += self.conn.data_to_send() @@ -79,7 +80,9 @@ class MockServer(httpcore.BaseReader, httpcore.BaseWriter): async def test_http2_get_request(): server = MockServer() origin = httpcore.Origin("http://example.org") - async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client: + async with httpcore.HTTP2Connection( + reader=server, writer=server, origin=origin + ) as client: response = await client.request("GET", "http://example.org") assert response.status_code == 200 assert json.loads(response.body) == {"method": "GET", "path": "/", "body": ""} @@ -89,17 +92,25 @@ async def test_http2_get_request(): async def test_http2_post_request(): server = MockServer() origin = httpcore.Origin("http://example.org") - async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client: + async with httpcore.HTTP2Connection( + reader=server, writer=server, origin=origin + ) as client: response = await client.request("POST", "http://example.org", body=b"") assert response.status_code == 200 - assert json.loads(response.body) == {"method": "POST", "path": "/", "body": ""} + assert json.loads(response.body) == { + "method": "POST", + "path": "/", + "body": "", + } @pytest.mark.asyncio async def test_http2_multiple_requests(): server = MockServer() origin = httpcore.Origin("http://example.org") - async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client: + async with httpcore.HTTP2Connection( + reader=server, writer=server, origin=origin + ) as client: response_1 = await client.request("GET", "http://example.org/1") response_2 = await client.request("GET", "http://example.org/2") response_3 = await client.request("GET", "http://example.org/3")