From: Florimond Manca Date: Thu, 29 Aug 2019 21:17:14 +0000 (+0200) Subject: Serve test server in thread (#292) X-Git-Tag: 0.7.3~29 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=994403bec920c923a6c3970215e75d6a99385e0c;p=thirdparty%2Fhttpx.git Serve test server in thread (#292) --- diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 8fc625f1..6af8f9c5 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -113,7 +113,7 @@ class Stream(BaseStream): ) break except asyncio.TimeoutError: - # We check our flag at the possible moment, in order to + # We check our flag at the first possible moment, in order to # allow us to suppress write timeouts, if we've since # switched over to read-timeout mode. should_raise = flag is None or flag.raise_on_write_timeout diff --git a/tests/client/test_client.py b/tests/client/test_client.py index ac558f5b..52d99e3b 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1,29 +1,8 @@ -import asyncio -import functools - import pytest import httpx -def threadpool(func): - """ - Our sync tests should run in separate thread to the uvicorn server. - """ - - @functools.wraps(func) - async def wrapped(*args, **kwargs): - nonlocal func - - loop = asyncio.get_event_loop() - if kwargs: - func = functools.partial(func, **kwargs) - await loop.run_in_executor(None, func, *args) - - return pytest.mark.asyncio(wrapped) - - -@threadpool def test_get(server): url = "http://127.0.0.1:8000/" with httpx.Client() as http: @@ -40,7 +19,6 @@ def test_get(server): assert repr(response) == "" -@threadpool def test_post(server): with httpx.Client() as http: response = http.post("http://127.0.0.1:8000/", data=b"Hello, world!") @@ -48,7 +26,6 @@ def test_post(server): assert response.reason_phrase == "OK" -@threadpool def test_post_json(server): with httpx.Client() as http: response = http.post("http://127.0.0.1:8000/", json={"text": "Hello, world!"}) @@ -56,7 +33,6 @@ def test_post_json(server): assert response.reason_phrase == "OK" -@threadpool def test_stream_response(server): with httpx.Client() as http: response = http.get("http://127.0.0.1:8000/", stream=True) @@ -65,7 +41,6 @@ def test_stream_response(server): assert content == b"Hello, world!" -@threadpool def test_stream_iterator(server): with httpx.Client() as http: response = http.get("http://127.0.0.1:8000/", stream=True) @@ -76,7 +51,6 @@ def test_stream_iterator(server): assert body == b"Hello, world!" -@threadpool def test_raw_iterator(server): with httpx.Client() as http: response = http.get("http://127.0.0.1:8000/", stream=True) @@ -88,7 +62,6 @@ def test_raw_iterator(server): response.close() # TODO: should Response be available as context managers? -@threadpool def test_raise_for_status(server): with httpx.Client() as client: for status_code in (200, 400, 404, 500, 505): @@ -103,7 +76,6 @@ def test_raise_for_status(server): assert response.raise_for_status() is None -@threadpool def test_options(server): with httpx.Client() as http: response = http.options("http://127.0.0.1:8000/") @@ -111,7 +83,6 @@ def test_options(server): assert response.reason_phrase == "OK" -@threadpool def test_head(server): with httpx.Client() as http: response = http.head("http://127.0.0.1:8000/") @@ -119,7 +90,6 @@ def test_head(server): assert response.reason_phrase == "OK" -@threadpool def test_put(server): with httpx.Client() as http: response = http.put("http://127.0.0.1:8000/", data=b"Hello, world!") @@ -127,7 +97,6 @@ def test_put(server): assert response.reason_phrase == "OK" -@threadpool def test_patch(server): with httpx.Client() as http: response = http.patch("http://127.0.0.1:8000/", data=b"Hello, world!") @@ -135,7 +104,6 @@ def test_patch(server): assert response.reason_phrase == "OK" -@threadpool def test_delete(server): with httpx.Client() as http: response = http.delete("http://127.0.0.1:8000/") @@ -143,7 +111,6 @@ def test_delete(server): assert response.reason_phrase == "OK" -@threadpool def test_base_url(server): base_url = "http://127.0.0.1:8000/" with httpx.Client(base_url=base_url) as http: diff --git a/tests/conftest.py b/tests/conftest.py index 2974da41..ec75e600 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,6 @@ import asyncio +import threading +import time import pytest import trustme @@ -101,76 +103,119 @@ class CAWithPKEncryption(trustme.CA): ) -@pytest.fixture +SERVER_SCOPE = "session" + + +@pytest.fixture(scope=SERVER_SCOPE) def example_cert(): ca = CAWithPKEncryption() ca.issue_cert("example.org") return ca -@pytest.fixture +@pytest.fixture(scope=SERVER_SCOPE) def cert_pem_file(example_cert): with example_cert.cert_pem.tempfile() as tmp: yield tmp -@pytest.fixture +@pytest.fixture(scope=SERVER_SCOPE) def cert_private_key_file(example_cert): with example_cert.private_key_pem.tempfile() as tmp: yield tmp -@pytest.fixture +@pytest.fixture(scope=SERVER_SCOPE) def cert_encrypted_private_key_file(example_cert): with example_cert.encrypted_private_key_pem.tempfile() as tmp: yield tmp +class TestServer(Server): + def install_signal_handlers(self) -> None: + # Disable the default installation of handlers for signals such as SIGTERM, + # because it can only be done in the main thread. + pass + + async def serve(self, sockets=None): + self.restart_requested = asyncio.Event() + + loop = asyncio.get_event_loop() + tasks = { + loop.create_task(super().serve(sockets=sockets)), + loop.create_task(self.watch_restarts()), + } + + await asyncio.wait(tasks) + + async def restart(self) -> None: + # Ensure we are in an asyncio environment. + assert asyncio.get_event_loop() is not None + # This may be called from a different thread than the one the server is + # running on. For this reason, we use an event to coordinate with the server + # instead of calling shutdown()/startup() directly. + self.restart_requested.set() + self.started = False + while not self.started: + await asyncio.sleep(0.5) + + async def watch_restarts(self): + while True: + if self.should_exit: + return + + try: + await asyncio.wait_for(self.restart_requested.wait(), timeout=0.1) + except asyncio.TimeoutError: + continue + + self.restart_requested.clear() + await self.shutdown() + await self.startup() + + @pytest.fixture -async def server(): - config = Config(app=app, lifespan="off") - server = Server(config=config) - task = asyncio.ensure_future(server.serve()) +def restart(backend): + """Restart the running server from an async test function. + + This fixture deals with possible differences between the environment of the + test function and that of the server. + """ + + async def restart(server): + await backend.run_in_threadpool(AsyncioBackend().run, server.restart) + + return restart + + +def serve_in_thread(server: Server): + thread = threading.Thread(target=server.run) + thread.start() try: while not server.started: - await asyncio.sleep(0.0001) + time.sleep(1e-3) yield server finally: server.should_exit = True - await task + thread.join() -@pytest.fixture -async def https_server(cert_pem_file, cert_private_key_file): +@pytest.fixture(scope=SERVER_SCOPE) +def server(): + config = Config(app=app, lifespan="off", loop="asyncio") + server = TestServer(config=config) + yield from serve_in_thread(server) + + +@pytest.fixture(scope=SERVER_SCOPE) +def https_server(cert_pem_file, cert_private_key_file): config = Config( app=app, lifespan="off", ssl_certfile=cert_pem_file, ssl_keyfile=cert_private_key_file, port=8001, + loop="asyncio", ) - server = Server(config=config) - task = asyncio.ensure_future(server.serve()) - try: - while not server.started: - await asyncio.sleep(0.0001) - yield server - finally: - server.should_exit = True - await task - - -@pytest.fixture -def restart(backend): - async def asyncio_restart(server): - await server.shutdown() - await server.startup() - - if isinstance(backend, AsyncioBackend): - return asyncio_restart - - # The uvicorn server runs under asyncio, so we will need to figure out - # how to restart it under a different I/O library. - # This will most likely require running `asyncio_restart` in the threadpool, - # but that might not be sufficient. - raise NotImplementedError + server = TestServer(config=config) + yield from serve_in_thread(server) diff --git a/tests/dispatch/test_connection_pools.py b/tests/dispatch/test_connection_pools.py index 4dd8424d..8f0ba70c 100644 --- a/tests/dispatch/test_connection_pools.py +++ b/tests/dispatch/test_connection_pools.py @@ -134,7 +134,7 @@ async def test_keepalive_connection_closed_by_server_is_reestablished( response = await http.request("GET", "http://127.0.0.1:8000/") await response.read() - # shutdown the server to close the keep-alive connection + # Shutdown the server to close the keep-alive connection await restart(server) response = await http.request("GET", "http://127.0.0.1:8000/") @@ -154,7 +154,7 @@ async def test_keepalive_http2_connection_closed_by_server_is_reestablished( response = await http.request("GET", "http://127.0.0.1:8000/") await response.read() - # shutdown the server to close the keep-alive connection + # Shutdown the server to close the keep-alive connection await restart(server) response = await http.request("GET", "http://127.0.0.1:8000/") diff --git a/tests/dispatch/test_connections.py b/tests/dispatch/test_connections.py index 59e9bb14..0aa77412 100644 --- a/tests/dispatch/test_connections.py +++ b/tests/dispatch/test_connections.py @@ -2,40 +2,42 @@ from httpx import HTTPConnection async def test_get(server, backend): - conn = HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend) - response = await conn.request("GET", "http://127.0.0.1:8000/") - await response.read() - assert response.status_code == 200 - assert response.content == b"Hello, world!" + async with HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend) as conn: + response = await conn.request("GET", "http://127.0.0.1:8000/") + await response.read() + assert response.status_code == 200 + assert response.content == b"Hello, world!" async def test_post(server, backend): - conn = HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend) - response = await conn.request( - "GET", "http://127.0.0.1:8000/", data=b"Hello, world!" - ) - assert response.status_code == 200 + async with HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend) as conn: + response = await conn.request( + "GET", "http://127.0.0.1:8000/", data=b"Hello, world!" + ) + assert response.status_code == 200 async def test_https_get_with_ssl_defaults(https_server, backend): """ An HTTPS request, with default SSL configuration set on the client. """ - conn = HTTPConnection( + async with HTTPConnection( origin="https://127.0.0.1:8001/", verify=False, backend=backend - ) - response = await conn.request("GET", "https://127.0.0.1:8001/") - await response.read() - assert response.status_code == 200 - assert response.content == b"Hello, world!" + ) as conn: + response = await conn.request("GET", "https://127.0.0.1:8001/") + await response.read() + assert response.status_code == 200 + assert response.content == b"Hello, world!" async def test_https_get_with_sll_overrides(https_server, backend): """ An HTTPS request, with SSL configuration set on the request. """ - conn = HTTPConnection(origin="https://127.0.0.1:8001/", backend=backend) - response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False) - await response.read() - assert response.status_code == 200 - assert response.content == b"Hello, world!" + async with HTTPConnection( + origin="https://127.0.0.1:8001/", backend=backend + ) as conn: + response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False) + await response.read() + assert response.status_code == 200 + assert response.content == b"Hello, world!" diff --git a/tests/test_api.py b/tests/test_api.py index de4267c7..cb9eb20c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,29 +1,8 @@ -import asyncio -import functools - import pytest import httpx -def threadpool(func): - """ - Our sync tests should run in separate thread to the uvicorn server. - """ - - @functools.wraps(func) - async def wrapped(*args, **kwargs): - nonlocal func - - loop = asyncio.get_event_loop() - if kwargs: - func = functools.partial(func, **kwargs) - await loop.run_in_executor(None, func, *args) - - return pytest.mark.asyncio(wrapped) - - -@threadpool def test_get(server): response = httpx.get("http://127.0.0.1:8000/") assert response.status_code == 200 @@ -32,14 +11,12 @@ def test_get(server): assert response.http_version == "HTTP/1.1" -@threadpool def test_post(server): response = httpx.post("http://127.0.0.1:8000/", data=b"Hello, world!") assert response.status_code == 200 assert response.reason_phrase == "OK" -@threadpool def test_post_byte_iterator(server): def data(): yield b"Hello" @@ -51,42 +28,36 @@ def test_post_byte_iterator(server): assert response.reason_phrase == "OK" -@threadpool def test_options(server): response = httpx.options("http://127.0.0.1:8000/") assert response.status_code == 200 assert response.reason_phrase == "OK" -@threadpool def test_head(server): response = httpx.head("http://127.0.0.1:8000/") assert response.status_code == 200 assert response.reason_phrase == "OK" -@threadpool def test_put(server): response = httpx.put("http://127.0.0.1:8000/", data=b"Hello, world!") assert response.status_code == 200 assert response.reason_phrase == "OK" -@threadpool def test_patch(server): response = httpx.patch("http://127.0.0.1:8000/", data=b"Hello, world!") assert response.status_code == 200 assert response.reason_phrase == "OK" -@threadpool def test_delete(server): response = httpx.delete("http://127.0.0.1:8000/") assert response.status_code == 200 assert response.reason_phrase == "OK" -@threadpool def test_get_invalid_url(server): with pytest.raises(httpx.InvalidURL): httpx.get("invalid://example.org") diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 870a592d..2fa01615 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -20,12 +20,17 @@ async def test_start_tls_on_socket_stream(https_server): timeout = TimeoutConfig(5) stream = await backend.connect("127.0.0.1", 8001, None, timeout) - assert stream.is_connection_dropped() is False - assert stream.stream_writer.get_extra_info("cipher", default=None) is None - stream = await backend.start_tls(stream, "127.0.0.1", ctx, timeout) - assert stream.is_connection_dropped() is False - assert stream.stream_writer.get_extra_info("cipher", default=None) is not None + try: + assert stream.is_connection_dropped() is False + assert stream.stream_writer.get_extra_info("cipher", default=None) is None - await stream.write(b"GET / HTTP/1.1\r\n\r\n") - assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n") + stream = await backend.start_tls(stream, "127.0.0.1", ctx, timeout) + assert stream.is_connection_dropped() is False + assert stream.stream_writer.get_extra_info("cipher", default=None) is not None + + await stream.write(b"GET / HTTP/1.1\r\n\r\n") + assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n") + + finally: + await stream.close()