From: Tom Christie Date: Wed, 17 Apr 2019 16:38:03 +0000 (+0100) Subject: Add PoolTimeout, and timeout tests X-Git-Tag: 0.1.0~6^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=148818c212a0c486e39dba0d56993564ffb08857;p=thirdparty%2Fhttpx.git Add PoolTimeout, and timeout tests --- diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 6630a361..7675068c 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -1,6 +1,13 @@ from .config import PoolLimits, SSLConfig, TimeoutConfig from .datastructures import URL, Request, Response -from .exceptions import ResponseClosed, StreamConsumed +from .exceptions import ( + ConnectTimeout, + PoolTimeout, + ReadTimeout, + ResponseClosed, + StreamConsumed, + Timeout, +) from .pool import ConnectionPool __version__ = "0.0.3" diff --git a/httpcore/compat.py b/httpcore/compat.py index 5369c5e8..794cf6f1 100644 --- a/httpcore/compat.py +++ b/httpcore/compat.py @@ -1,4 +1,4 @@ try: import brotli -except ImportError: - brotli = None # pragma: nocover +except ImportError: # pragma: nocover + brotli = None diff --git a/httpcore/pool.py b/httpcore/pool.py index fcec56ea..39216c3d 100644 --- a/httpcore/pool.py +++ b/httpcore/pool.py @@ -16,18 +16,23 @@ from .config import ( ) from .connections import Connection from .datastructures import URL, Request, Response +from .exceptions import PoolTimeout ConnectionKey = typing.Tuple[str, str, int] # (scheme, host, port) class ConnectionSemaphore: - def __init__(self, max_connections: int = None): + def __init__(self, max_connections: int = None, timeout: float = None): + self.timeout = timeout if max_connections is not None: self.semaphore = asyncio.BoundedSemaphore(value=max_connections) async def acquire(self) -> None: if hasattr(self, "semaphore"): - await self.semaphore.acquire() + try: + await asyncio.wait_for(self.semaphore.acquire(), self.timeout) + except asyncio.TimeoutError: + raise PoolTimeout() def release(self) -> None: if hasattr(self, "semaphore"): @@ -52,7 +57,7 @@ class ConnectionPool: {} ) # type: typing.Dict[ConnectionKey, typing.List[Connection]] self._connection_semaphore = ConnectionSemaphore( - max_connections=self.limits.hard_limit + max_connections=self.limits.hard_limit, timeout=self.timeout.pool_timeout ) async def request( diff --git a/tests/conftest.py b/tests/conftest.py index efb79df1..3ceaeec0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,25 @@ from uvicorn.main import Server async def app(scope, receive, send): assert scope["type"] == "http" + if scope["path"] == "/slow_response": + await slow_response(scope, receive, send) + else: + await hello_world(scope, receive, send) + + +async def hello_world(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + + +async def slow_response(scope, receive, send): + await asyncio.sleep(0.01) await send( { "type": "http.response.start", @@ -28,5 +47,4 @@ async def server(): yield server finally: server.should_exit = True - server.force_exit = True await task diff --git a/tests/test_timeouts.py b/tests/test_timeouts.py new file mode 100644 index 00000000..61ce9a18 --- /dev/null +++ b/tests/test_timeouts.py @@ -0,0 +1,26 @@ +import pytest + +import httpcore + + +@pytest.mark.asyncio +async def test_read_timeout(server): + timeout = httpcore.TimeoutConfig(read_timeout=0.0001) + + async with httpcore.ConnectionPool(timeout=timeout) as http: + with pytest.raises(httpcore.ReadTimeout): + await http.request("GET", "http://127.0.0.1:8000/slow_response") + + +@pytest.mark.asyncio +async def test_pool_timeout(server): + timeout = httpcore.TimeoutConfig(pool_timeout=0.0001) + limits = httpcore.PoolLimits(hard_limit=1) + + async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as http: + response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + + with pytest.raises(httpcore.PoolTimeout): + await http.request("GET", "http://localhost:8000/") + + await response.read()