]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Serve test server in thread (#292)
authorFlorimond Manca <florimond.manca@gmail.com>
Thu, 29 Aug 2019 21:17:14 +0000 (23:17 +0200)
committerGitHub <noreply@github.com>
Thu, 29 Aug 2019 21:17:14 +0000 (23:17 +0200)
httpx/concurrency/asyncio.py
tests/client/test_client.py
tests/conftest.py
tests/dispatch/test_connection_pools.py
tests/dispatch/test_connections.py
tests/test_api.py
tests/test_concurrency.py

index 8fc625f1c06b82074e86582ec03c9a2f188248b7..6af8f9c5bbf20a6f0d9dfec3be3350f4b2331438 100644 (file)
@@ -113,7 +113,7 @@ class Stream(BaseStream):
                 )
                 break
             except asyncio.TimeoutError:
-                # We check our flag at the possible moment, in order to
+                # We check our flag at the first possible moment, in order to
                 # allow us to suppress write timeouts, if we've since
                 # switched over to read-timeout mode.
                 should_raise = flag is None or flag.raise_on_write_timeout
index ac558f5b31fe5c689b765be4f515796e898d0565..52d99e3b0c8c29256f21057810b141d3eb5a282a 100644 (file)
@@ -1,29 +1,8 @@
-import asyncio
-import functools
-
 import pytest
 
 import httpx
 
 
-def threadpool(func):
-    """
-    Our sync tests should run in separate thread to the uvicorn server.
-    """
-
-    @functools.wraps(func)
-    async def wrapped(*args, **kwargs):
-        nonlocal func
-
-        loop = asyncio.get_event_loop()
-        if kwargs:
-            func = functools.partial(func, **kwargs)
-        await loop.run_in_executor(None, func, *args)
-
-    return pytest.mark.asyncio(wrapped)
-
-
-@threadpool
 def test_get(server):
     url = "http://127.0.0.1:8000/"
     with httpx.Client() as http:
@@ -40,7 +19,6 @@ def test_get(server):
     assert repr(response) == "<Response [200 OK]>"
 
 
-@threadpool
 def test_post(server):
     with httpx.Client() as http:
         response = http.post("http://127.0.0.1:8000/", data=b"Hello, world!")
@@ -48,7 +26,6 @@ def test_post(server):
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_post_json(server):
     with httpx.Client() as http:
         response = http.post("http://127.0.0.1:8000/", json={"text": "Hello, world!"})
@@ -56,7 +33,6 @@ def test_post_json(server):
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_stream_response(server):
     with httpx.Client() as http:
         response = http.get("http://127.0.0.1:8000/", stream=True)
@@ -65,7 +41,6 @@ def test_stream_response(server):
     assert content == b"Hello, world!"
 
 
-@threadpool
 def test_stream_iterator(server):
     with httpx.Client() as http:
         response = http.get("http://127.0.0.1:8000/", stream=True)
@@ -76,7 +51,6 @@ def test_stream_iterator(server):
     assert body == b"Hello, world!"
 
 
-@threadpool
 def test_raw_iterator(server):
     with httpx.Client() as http:
         response = http.get("http://127.0.0.1:8000/", stream=True)
@@ -88,7 +62,6 @@ def test_raw_iterator(server):
     response.close()  # TODO: should Response be available as context managers?
 
 
-@threadpool
 def test_raise_for_status(server):
     with httpx.Client() as client:
         for status_code in (200, 400, 404, 500, 505):
@@ -103,7 +76,6 @@ def test_raise_for_status(server):
                 assert response.raise_for_status() is None
 
 
-@threadpool
 def test_options(server):
     with httpx.Client() as http:
         response = http.options("http://127.0.0.1:8000/")
@@ -111,7 +83,6 @@ def test_options(server):
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_head(server):
     with httpx.Client() as http:
         response = http.head("http://127.0.0.1:8000/")
@@ -119,7 +90,6 @@ def test_head(server):
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_put(server):
     with httpx.Client() as http:
         response = http.put("http://127.0.0.1:8000/", data=b"Hello, world!")
@@ -127,7 +97,6 @@ def test_put(server):
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_patch(server):
     with httpx.Client() as http:
         response = http.patch("http://127.0.0.1:8000/", data=b"Hello, world!")
@@ -135,7 +104,6 @@ def test_patch(server):
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_delete(server):
     with httpx.Client() as http:
         response = http.delete("http://127.0.0.1:8000/")
@@ -143,7 +111,6 @@ def test_delete(server):
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_base_url(server):
     base_url = "http://127.0.0.1:8000/"
     with httpx.Client(base_url=base_url) as http:
index 2974da416b4c3f78982d2b8bdcbb25b3060a1bd7..ec75e60093460f44944143f09e7144eaabd206eb 100644 (file)
@@ -1,4 +1,6 @@
 import asyncio
+import threading
+import time
 
 import pytest
 import trustme
@@ -101,76 +103,119 @@ class CAWithPKEncryption(trustme.CA):
         )
 
 
-@pytest.fixture
+SERVER_SCOPE = "session"
+
+
+@pytest.fixture(scope=SERVER_SCOPE)
 def example_cert():
     ca = CAWithPKEncryption()
     ca.issue_cert("example.org")
     return ca
 
 
-@pytest.fixture
+@pytest.fixture(scope=SERVER_SCOPE)
 def cert_pem_file(example_cert):
     with example_cert.cert_pem.tempfile() as tmp:
         yield tmp
 
 
-@pytest.fixture
+@pytest.fixture(scope=SERVER_SCOPE)
 def cert_private_key_file(example_cert):
     with example_cert.private_key_pem.tempfile() as tmp:
         yield tmp
 
 
-@pytest.fixture
+@pytest.fixture(scope=SERVER_SCOPE)
 def cert_encrypted_private_key_file(example_cert):
     with example_cert.encrypted_private_key_pem.tempfile() as tmp:
         yield tmp
 
 
+class TestServer(Server):
+    def install_signal_handlers(self) -> None:
+        # Disable the default installation of handlers for signals such as SIGTERM,
+        # because it can only be done in the main thread.
+        pass
+
+    async def serve(self, sockets=None):
+        self.restart_requested = asyncio.Event()
+
+        loop = asyncio.get_event_loop()
+        tasks = {
+            loop.create_task(super().serve(sockets=sockets)),
+            loop.create_task(self.watch_restarts()),
+        }
+
+        await asyncio.wait(tasks)
+
+    async def restart(self) -> None:
+        # Ensure we are in an asyncio environment.
+        assert asyncio.get_event_loop() is not None
+        # This may be called from a different thread than the one the server is
+        # running on. For this reason, we use an event to coordinate with the server
+        # instead of calling shutdown()/startup() directly.
+        self.restart_requested.set()
+        self.started = False
+        while not self.started:
+            await asyncio.sleep(0.5)
+
+    async def watch_restarts(self):
+        while True:
+            if self.should_exit:
+                return
+
+            try:
+                await asyncio.wait_for(self.restart_requested.wait(), timeout=0.1)
+            except asyncio.TimeoutError:
+                continue
+
+            self.restart_requested.clear()
+            await self.shutdown()
+            await self.startup()
+
+
 @pytest.fixture
-async def server():
-    config = Config(app=app, lifespan="off")
-    server = Server(config=config)
-    task = asyncio.ensure_future(server.serve())
+def restart(backend):
+    """Restart the running server from an async test function.
+
+    This fixture deals with possible differences between the environment of the
+    test function and that of the server.
+    """
+
+    async def restart(server):
+        await backend.run_in_threadpool(AsyncioBackend().run, server.restart)
+
+    return restart
+
+
+def serve_in_thread(server: Server):
+    thread = threading.Thread(target=server.run)
+    thread.start()
     try:
         while not server.started:
-            await asyncio.sleep(0.0001)
+            time.sleep(1e-3)
         yield server
     finally:
         server.should_exit = True
-        await task
+        thread.join()
 
 
-@pytest.fixture
-async def https_server(cert_pem_file, cert_private_key_file):
+@pytest.fixture(scope=SERVER_SCOPE)
+def server():
+    config = Config(app=app, lifespan="off", loop="asyncio")
+    server = TestServer(config=config)
+    yield from serve_in_thread(server)
+
+
+@pytest.fixture(scope=SERVER_SCOPE)
+def https_server(cert_pem_file, cert_private_key_file):
     config = Config(
         app=app,
         lifespan="off",
         ssl_certfile=cert_pem_file,
         ssl_keyfile=cert_private_key_file,
         port=8001,
+        loop="asyncio",
     )
-    server = Server(config=config)
-    task = asyncio.ensure_future(server.serve())
-    try:
-        while not server.started:
-            await asyncio.sleep(0.0001)
-        yield server
-    finally:
-        server.should_exit = True
-        await task
-
-
-@pytest.fixture
-def restart(backend):
-    async def asyncio_restart(server):
-        await server.shutdown()
-        await server.startup()
-
-    if isinstance(backend, AsyncioBackend):
-        return asyncio_restart
-
-    # The uvicorn server runs under asyncio, so we will need to figure out
-    # how to restart it under a different I/O library.
-    # This will most likely require running `asyncio_restart` in the threadpool,
-    # but that might not be sufficient.
-    raise NotImplementedError
+    server = TestServer(config=config)
+    yield from serve_in_thread(server)
index 4dd8424de9dcf75c103d524c427186ad25b0a54f..8f0ba70c93091de2cad1f801b607554d88263017 100644 (file)
@@ -134,7 +134,7 @@ async def test_keepalive_connection_closed_by_server_is_reestablished(
         response = await http.request("GET", "http://127.0.0.1:8000/")
         await response.read()
 
-        # shutdown the server to close the keep-alive connection
+        # Shutdown the server to close the keep-alive connection
         await restart(server)
 
         response = await http.request("GET", "http://127.0.0.1:8000/")
@@ -154,7 +154,7 @@ async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
         response = await http.request("GET", "http://127.0.0.1:8000/")
         await response.read()
 
-        # shutdown the server to close the keep-alive connection
+        # Shutdown the server to close the keep-alive connection
         await restart(server)
 
         response = await http.request("GET", "http://127.0.0.1:8000/")
index 59e9bb1412ba9958cc028fa5217aa2347550f00f..0aa7741287ba49c81359fbd4e7fd2660bc8ab054 100644 (file)
@@ -2,40 +2,42 @@ from httpx import HTTPConnection
 
 
 async def test_get(server, backend):
-    conn = HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend)
-    response = await conn.request("GET", "http://127.0.0.1:8000/")
-    await response.read()
-    assert response.status_code == 200
-    assert response.content == b"Hello, world!"
+    async with HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend) as conn:
+        response = await conn.request("GET", "http://127.0.0.1:8000/")
+        await response.read()
+        assert response.status_code == 200
+        assert response.content == b"Hello, world!"
 
 
 async def test_post(server, backend):
-    conn = HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend)
-    response = await conn.request(
-        "GET", "http://127.0.0.1:8000/", data=b"Hello, world!"
-    )
-    assert response.status_code == 200
+    async with HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend) as conn:
+        response = await conn.request(
+            "GET", "http://127.0.0.1:8000/", data=b"Hello, world!"
+        )
+        assert response.status_code == 200
 
 
 async def test_https_get_with_ssl_defaults(https_server, backend):
     """
     An HTTPS request, with default SSL configuration set on the client.
     """
-    conn = HTTPConnection(
+    async with HTTPConnection(
         origin="https://127.0.0.1:8001/", verify=False, backend=backend
-    )
-    response = await conn.request("GET", "https://127.0.0.1:8001/")
-    await response.read()
-    assert response.status_code == 200
-    assert response.content == b"Hello, world!"
+    ) as conn:
+        response = await conn.request("GET", "https://127.0.0.1:8001/")
+        await response.read()
+        assert response.status_code == 200
+        assert response.content == b"Hello, world!"
 
 
 async def test_https_get_with_sll_overrides(https_server, backend):
     """
     An HTTPS request, with SSL configuration set on the request.
     """
-    conn = HTTPConnection(origin="https://127.0.0.1:8001/", backend=backend)
-    response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False)
-    await response.read()
-    assert response.status_code == 200
-    assert response.content == b"Hello, world!"
+    async with HTTPConnection(
+        origin="https://127.0.0.1:8001/", backend=backend
+    ) as conn:
+        response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False)
+        await response.read()
+        assert response.status_code == 200
+        assert response.content == b"Hello, world!"
index de4267c713002627085c55340fb89f9164bcb8ff..cb9eb20c66e656e54025ab88fd5d60e8717f4667 100644 (file)
@@ -1,29 +1,8 @@
-import asyncio
-import functools
-
 import pytest
 
 import httpx
 
 
-def threadpool(func):
-    """
-    Our sync tests should run in separate thread to the uvicorn server.
-    """
-
-    @functools.wraps(func)
-    async def wrapped(*args, **kwargs):
-        nonlocal func
-
-        loop = asyncio.get_event_loop()
-        if kwargs:
-            func = functools.partial(func, **kwargs)
-        await loop.run_in_executor(None, func, *args)
-
-    return pytest.mark.asyncio(wrapped)
-
-
-@threadpool
 def test_get(server):
     response = httpx.get("http://127.0.0.1:8000/")
     assert response.status_code == 200
@@ -32,14 +11,12 @@ def test_get(server):
     assert response.http_version == "HTTP/1.1"
 
 
-@threadpool
 def test_post(server):
     response = httpx.post("http://127.0.0.1:8000/", data=b"Hello, world!")
     assert response.status_code == 200
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_post_byte_iterator(server):
     def data():
         yield b"Hello"
@@ -51,42 +28,36 @@ def test_post_byte_iterator(server):
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_options(server):
     response = httpx.options("http://127.0.0.1:8000/")
     assert response.status_code == 200
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_head(server):
     response = httpx.head("http://127.0.0.1:8000/")
     assert response.status_code == 200
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_put(server):
     response = httpx.put("http://127.0.0.1:8000/", data=b"Hello, world!")
     assert response.status_code == 200
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_patch(server):
     response = httpx.patch("http://127.0.0.1:8000/", data=b"Hello, world!")
     assert response.status_code == 200
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_delete(server):
     response = httpx.delete("http://127.0.0.1:8000/")
     assert response.status_code == 200
     assert response.reason_phrase == "OK"
 
 
-@threadpool
 def test_get_invalid_url(server):
     with pytest.raises(httpx.InvalidURL):
         httpx.get("invalid://example.org")
index 870a592d01b2d00543f3ee53f0055dd6bb77dc1b..2fa0161510dcbd1b7a9712eeb730d91ab8809f5a 100644 (file)
@@ -20,12 +20,17 @@ async def test_start_tls_on_socket_stream(https_server):
     timeout = TimeoutConfig(5)
 
     stream = await backend.connect("127.0.0.1", 8001, None, timeout)
-    assert stream.is_connection_dropped() is False
-    assert stream.stream_writer.get_extra_info("cipher", default=None) is None
 
-    stream = await backend.start_tls(stream, "127.0.0.1", ctx, timeout)
-    assert stream.is_connection_dropped() is False
-    assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
+    try:
+        assert stream.is_connection_dropped() is False
+        assert stream.stream_writer.get_extra_info("cipher", default=None) is None
 
-    await stream.write(b"GET / HTTP/1.1\r\n\r\n")
-    assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")
+        stream = await backend.start_tls(stream, "127.0.0.1", ctx, timeout)
+        assert stream.is_connection_dropped() is False
+        assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
+
+        await stream.write(b"GET / HTTP/1.1\r\n\r\n")
+        assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")
+
+    finally:
+        await stream.close()