]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Drop `run` and `run_in_threadpool` (#710)
authorTom Christie <tom@tomchristie.com>
Mon, 6 Jan 2020 11:14:43 +0000 (11:14 +0000)
committerGitHub <noreply@github.com>
Mon, 6 Jan 2020 11:14:43 +0000 (11:14 +0000)
* Drop run and run_in_threadpool

* Fix server restart errors

* Re-introduce 'sleep' as a concurrency test utility

* Simpler test concurrency utils

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
httpx/backends/asyncio.py
httpx/backends/auto.py
httpx/backends/base.py
httpx/backends/trio.py
tests/concurrency.py
tests/conftest.py
tests/dispatch/test_connection_pools.py
tests/test_concurrency.py

index 854618e119d0b01fabeb67d422245daaabf60c2a..74d2b0d4f854424131c49c28584811306357934e 100644 (file)
@@ -1,5 +1,4 @@
 import asyncio
-import functools
 import ssl
 import typing
 
@@ -182,15 +181,6 @@ class AsyncioBackend(ConcurrencyBackend):
             ssl_monkey_patch()
         SSL_MONKEY_PATCH_APPLIED = True
 
-    @property
-    def loop(self) -> asyncio.AbstractEventLoop:
-        if not hasattr(self, "_loop"):
-            try:
-                self._loop = asyncio.get_event_loop()
-            except RuntimeError:
-                self._loop = asyncio.new_event_loop()
-        return self._loop
-
     async def open_tcp_stream(
         self,
         hostname: str,
@@ -233,25 +223,6 @@ class AsyncioBackend(ConcurrencyBackend):
         loop = asyncio.get_event_loop()
         return loop.time()
 
-    async def run_in_threadpool(
-        self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
-    ) -> typing.Any:
-        if kwargs:
-            # loop.run_in_executor doesn't accept 'kwargs', so bind them in here
-            func = functools.partial(func, **kwargs)
-        return await self.loop.run_in_executor(None, func, *args)
-
-    def run(
-        self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
-    ) -> typing.Any:
-        loop = self.loop
-        if loop.is_running():
-            self._loop = asyncio.new_event_loop()
-        try:
-            return self.loop.run_until_complete(coroutine(*args, **kwargs))
-        finally:
-            self._loop = loop
-
     def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
         return Semaphore(max_value, exc_class)
 
index 0c0d2d88acf6cd4fe3247853a8da7412124e827d..7a8c597822104f037df0262c283aae4c0fec6ad5 100644 (file)
@@ -44,11 +44,6 @@ class AutoBackend(ConcurrencyBackend):
     def time(self) -> float:
         return self.backend.time()
 
-    async def run_in_threadpool(
-        self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
-    ) -> typing.Any:
-        return await self.backend.run_in_threadpool(func, *args, **kwargs)
-
     def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
         return self.backend.create_semaphore(max_value, exc_class)
 
index e55b6363d353f7901e2457f8e98dee004a3f5a1a..964d09449f098d83178b401b508251398a128cd2 100644 (file)
@@ -114,16 +114,6 @@ class ConcurrencyBackend:
     def time(self) -> float:
         raise NotImplementedError()  # pragma: no cover
 
-    async def run_in_threadpool(
-        self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
-    ) -> typing.Any:
-        raise NotImplementedError()  # pragma: no cover
-
-    def run(
-        self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
-    ) -> typing.Any:
-        raise NotImplementedError()  # pragma: no cover
-
     def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
         raise NotImplementedError()  # pragma: no cover
 
index 6e79dc8b5b08dca9c87dc63bb11283d3ad7ef115..979aa450b7899f10ec7f09e6e05f6ef91f51a32c 100644 (file)
@@ -1,4 +1,3 @@
-import functools
 import ssl
 import typing
 
@@ -120,20 +119,6 @@ class TrioBackend(ConcurrencyBackend):
 
         raise ConnectTimeout()
 
-    async def run_in_threadpool(
-        self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
-    ) -> typing.Any:
-        return await trio.to_thread.run_sync(
-            functools.partial(func, **kwargs) if kwargs else func, *args
-        )
-
-    def run(
-        self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
-    ) -> typing.Any:
-        return trio.run(
-            functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
-        )
-
     def time(self) -> float:
         return trio.current_time()
 
index 361e29b1cbe3449885a63ba1e37c18c63eb76ace..83a6a0609745004c36e98f075eb0549fd4e927c9 100644 (file)
@@ -4,54 +4,34 @@ required as part of the ConcurrencyBackend API.
 """
 
 import asyncio
-import functools
-import typing
 
+import sniffio
 import trio
 
-from httpx.backends.asyncio import AsyncioBackend
-from httpx.backends.auto import AutoBackend
-from httpx.backends.trio import TrioBackend
 
-
-@functools.singledispatch
-async def run_concurrently(backend, *coroutines: typing.Callable[[], typing.Awaitable]):
-    raise NotImplementedError  # pragma: no cover
-
-
-@run_concurrently.register(AutoBackend)
-async def _run_concurrently_auto(backend, *coroutines):
-    await run_concurrently(backend.backend, *coroutines)
-
-
-@run_concurrently.register(AsyncioBackend)
-async def _run_concurrently_asyncio(backend, *coroutines):
-    coros = (coroutine() for coroutine in coroutines)
-    await asyncio.gather(*coros)
-
-
-@run_concurrently.register(TrioBackend)
-async def _run_concurrently_trio(backend, *coroutines):
-    async with trio.open_nursery() as nursery:
-        for coroutine in coroutines:
-            nursery.start_soon(coroutine)
-
-
-@functools.singledispatch
-def get_cipher(backend, stream):
-    raise NotImplementedError  # pragma: no cover
-
-
-@get_cipher.register(AutoBackend)
-def _get_cipher_auto(backend, stream):
-    return get_cipher(backend.backend, stream)
-
-
-@get_cipher.register(AsyncioBackend)
-def _get_cipher_asyncio(backend, stream):
-    return stream.stream_writer.get_extra_info("cipher", default=None)
-
-
-@get_cipher.register(TrioBackend)
-def get_trio_cipher(backend, stream):
-    return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None
+async def sleep(seconds: float):
+    if sniffio.current_async_library() == "trio":
+        await trio.sleep(seconds)
+    else:
+        await asyncio.sleep(seconds)
+
+
+async def run_concurrently(*coroutines):
+    if sniffio.current_async_library() == "trio":
+        async with trio.open_nursery() as nursery:
+            for coroutine in coroutines:
+                nursery.start_soon(coroutine)
+    else:
+        coros = (coroutine() for coroutine in coroutines)
+        await asyncio.gather(*coros)
+
+
+def get_cipher(stream):
+    if sniffio.current_async_library() == "trio":
+        return (
+            stream.stream.cipher()
+            if isinstance(stream.stream, trio.SSLStream)
+            else None
+        )
+    else:
+        return stream.stream_writer.get_extra_info("cipher", default=None)
index d6f07b16f06339a3c0e51c27b2e8b53ff4fefeeb..3edb46681e94c1e876d2d8a8b053489153e13b75 100644 (file)
@@ -18,8 +18,7 @@ from uvicorn.config import Config
 from uvicorn.main import Server
 
 from httpx import URL
-from httpx.backends.asyncio import AsyncioBackend
-from httpx.backends.base import lookup_backend
+from tests.concurrency import sleep
 
 ENVIRONMENT_VARIABLES = {
     "SSL_CERT_FILE",
@@ -108,7 +107,7 @@ async def slow_response(scope, receive, send):
         delay_ms = float(delay_ms_str)
     except ValueError:
         delay_ms = 100
-    await asyncio.sleep(delay_ms / 1000.0)
+    await sleep(delay_ms / 1000.0)
     await send(
         {
             "type": "http.response.start",
@@ -252,15 +251,15 @@ class TestServer(Server):
         await asyncio.wait(tasks)
 
     async def restart(self) -> None:
-        # Ensure we are in an asyncio environment.
-        assert asyncio.get_event_loop() is not None
-        # This may be called from a different thread than the one the server is
-        # running on. For this reason, we use an event to coordinate with the server
-        # instead of calling shutdown()/startup() directly.
-        self.restart_requested.set()
+        # This coroutine may be called from a different thread than the one the
+        # server is running on, and from an async environment that's not asyncio.
+        # For this reason, we use an event to coordinate with the server
+        # instead of calling shutdown()/startup() directly, and should not make
+        # any asyncio-specific operations.
         self.started = False
+        self.restart_requested.set()
         while not self.started:
-            await asyncio.sleep(0.5)
+            await sleep(0.2)
 
     async def watch_restarts(self):
         while True:
@@ -277,22 +276,6 @@ class TestServer(Server):
             await self.startup()
 
 
-@pytest.fixture
-def restart():
-    """Restart the running server from an async test function.
-
-    This fixture deals with possible differences between the environment of the
-    test function and that of the server.
-    """
-    asyncio_backend = AsyncioBackend()
-
-    async def restart(server):
-        backend = lookup_backend()
-        await backend.run_in_threadpool(asyncio_backend.run, server.restart)
-
-    return restart
-
-
 def serve_in_thread(server: Server):
     thread = threading.Thread(target=server.run)
     thread.start()
index 0ec12a7c6cc855b1096537c0d649724beeb9f508..78209585552208e14e3eafb232b9bca053cf0348 100644 (file)
@@ -167,7 +167,7 @@ async def test_premature_response_close(server):
 
 
 @pytest.mark.usefixtures("async_environment")
-async def test_keepalive_connection_closed_by_server_is_reestablished(server, restart):
+async def test_keepalive_connection_closed_by_server_is_reestablished(server):
     """
     Upon keep-alive connection closed by remote a new connection
     should be reestablished.
@@ -177,7 +177,7 @@ async def test_keepalive_connection_closed_by_server_is_reestablished(server, re
         await response.aread()
 
         # Shutdown the server to close the keep-alive connection
-        await restart(server)
+        await server.restart()
 
         response = await http.request("GET", server.url)
         await response.aread()
@@ -186,9 +186,7 @@ async def test_keepalive_connection_closed_by_server_is_reestablished(server, re
 
 
 @pytest.mark.usefixtures("async_environment")
-async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
-    server, restart
-):
+async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server):
     """
     Upon keep-alive connection closed by remote a new connection
     should be reestablished.
@@ -198,7 +196,7 @@ async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
         await response.aread()
 
         # Shutdown the server to close the keep-alive connection
-        await restart(server)
+        await server.restart()
 
         response = await http.request("GET", server.url)
         await response.aread()
@@ -207,7 +205,7 @@ async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
 
 
 @pytest.mark.usefixtures("async_environment")
-async def test_connection_closed_free_semaphore_on_acquire(server, restart):
+async def test_connection_closed_free_semaphore_on_acquire(server):
     """
     Verify that max_connections semaphore is released
     properly on a disconnected connection.
@@ -217,7 +215,7 @@ async def test_connection_closed_free_semaphore_on_acquire(server, restart):
         await response.aread()
 
         # Close the connection so we're forced to recycle it
-        await restart(server)
+        await server.restart()
 
         response = await http.request("GET", server.url)
         assert response.status_code == 200
index c7e157ebfb60d816f00889e93ba6b325df4a247e..3ed755209a61c47302f2add6073af79899e0945b 100644 (file)
@@ -41,11 +41,11 @@ async def test_start_tls_on_tcp_socket_stream(https_server):
 
     try:
         assert stream.is_connection_dropped() is False
-        assert get_cipher(backend, stream) is None
+        assert get_cipher(stream) is None
 
         stream = await stream.start_tls(https_server.url.host, ctx, timeout)
         assert stream.is_connection_dropped() is False
-        assert get_cipher(backend, stream) is not None
+        assert get_cipher(stream) is not None
 
         await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
 
@@ -68,11 +68,11 @@ async def test_start_tls_on_uds_socket_stream(https_uds_server):
 
     try:
         assert stream.is_connection_dropped() is False
-        assert get_cipher(backend, stream) is None
+        assert get_cipher(stream) is None
 
         stream = await stream.start_tls(https_uds_server.url.host, ctx, timeout)
         assert stream.is_connection_dropped() is False
-        assert get_cipher(backend, stream) is not None
+        assert get_cipher(stream) is not None
 
         await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
 
@@ -96,7 +96,7 @@ async def test_concurrent_read(server):
     try:
         await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
         await run_concurrently(
-            backend, lambda: stream.read(10, timeout), lambda: stream.read(10, timeout)
+            lambda: stream.read(10, timeout), lambda: stream.read(10, timeout)
         )
     finally:
         await stream.close()