import asyncio
-import functools
import ssl
import typing
ssl_monkey_patch()
SSL_MONKEY_PATCH_APPLIED = True
- @property
- def loop(self) -> asyncio.AbstractEventLoop:
- if not hasattr(self, "_loop"):
- try:
- self._loop = asyncio.get_event_loop()
- except RuntimeError:
- self._loop = asyncio.new_event_loop()
- return self._loop
-
async def open_tcp_stream(
self,
hostname: str,
loop = asyncio.get_event_loop()
return loop.time()
- async def run_in_threadpool(
- self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
- ) -> typing.Any:
- if kwargs:
- # loop.run_in_executor doesn't accept 'kwargs', so bind them in here
- func = functools.partial(func, **kwargs)
- return await self.loop.run_in_executor(None, func, *args)
-
- def run(
- self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
- ) -> typing.Any:
- loop = self.loop
- if loop.is_running():
- self._loop = asyncio.new_event_loop()
- try:
- return self.loop.run_until_complete(coroutine(*args, **kwargs))
- finally:
- self._loop = loop
-
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
return Semaphore(max_value, exc_class)
def time(self) -> float:
return self.backend.time()
- async def run_in_threadpool(
- self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
- ) -> typing.Any:
- return await self.backend.run_in_threadpool(func, *args, **kwargs)
-
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
return self.backend.create_semaphore(max_value, exc_class)
def time(self) -> float:
raise NotImplementedError() # pragma: no cover
- async def run_in_threadpool(
- self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
- ) -> typing.Any:
- raise NotImplementedError() # pragma: no cover
-
- def run(
- self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
- ) -> typing.Any:
- raise NotImplementedError() # pragma: no cover
-
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
raise NotImplementedError() # pragma: no cover
-import functools
import ssl
import typing
raise ConnectTimeout()
- async def run_in_threadpool(
- self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
- ) -> typing.Any:
- return await trio.to_thread.run_sync(
- functools.partial(func, **kwargs) if kwargs else func, *args
- )
-
- def run(
- self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
- ) -> typing.Any:
- return trio.run(
- functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
- )
-
def time(self) -> float:
return trio.current_time()
"""
import asyncio
-import functools
-import typing
+import sniffio
import trio
-from httpx.backends.asyncio import AsyncioBackend
-from httpx.backends.auto import AutoBackend
-from httpx.backends.trio import TrioBackend
-
-@functools.singledispatch
-async def run_concurrently(backend, *coroutines: typing.Callable[[], typing.Awaitable]):
- raise NotImplementedError # pragma: no cover
-
-
-@run_concurrently.register(AutoBackend)
-async def _run_concurrently_auto(backend, *coroutines):
- await run_concurrently(backend.backend, *coroutines)
-
-
-@run_concurrently.register(AsyncioBackend)
-async def _run_concurrently_asyncio(backend, *coroutines):
- coros = (coroutine() for coroutine in coroutines)
- await asyncio.gather(*coros)
-
-
-@run_concurrently.register(TrioBackend)
-async def _run_concurrently_trio(backend, *coroutines):
- async with trio.open_nursery() as nursery:
- for coroutine in coroutines:
- nursery.start_soon(coroutine)
-
-
-@functools.singledispatch
-def get_cipher(backend, stream):
- raise NotImplementedError # pragma: no cover
-
-
-@get_cipher.register(AutoBackend)
-def _get_cipher_auto(backend, stream):
- return get_cipher(backend.backend, stream)
-
-
-@get_cipher.register(AsyncioBackend)
-def _get_cipher_asyncio(backend, stream):
- return stream.stream_writer.get_extra_info("cipher", default=None)
-
-
-@get_cipher.register(TrioBackend)
-def get_trio_cipher(backend, stream):
- return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None
+async def sleep(seconds: float):
+ if sniffio.current_async_library() == "trio":
+ await trio.sleep(seconds)
+ else:
+ await asyncio.sleep(seconds)
+
+
+async def run_concurrently(*coroutines):
+ if sniffio.current_async_library() == "trio":
+ async with trio.open_nursery() as nursery:
+ for coroutine in coroutines:
+ nursery.start_soon(coroutine)
+ else:
+ coros = (coroutine() for coroutine in coroutines)
+ await asyncio.gather(*coros)
+
+
+def get_cipher(stream):
+ if sniffio.current_async_library() == "trio":
+ return (
+ stream.stream.cipher()
+ if isinstance(stream.stream, trio.SSLStream)
+ else None
+ )
+ else:
+ return stream.stream_writer.get_extra_info("cipher", default=None)
from uvicorn.main import Server
from httpx import URL
-from httpx.backends.asyncio import AsyncioBackend
-from httpx.backends.base import lookup_backend
+from tests.concurrency import sleep
ENVIRONMENT_VARIABLES = {
"SSL_CERT_FILE",
delay_ms = float(delay_ms_str)
except ValueError:
delay_ms = 100
- await asyncio.sleep(delay_ms / 1000.0)
+ await sleep(delay_ms / 1000.0)
await send(
{
"type": "http.response.start",
await asyncio.wait(tasks)
async def restart(self) -> None:
- # Ensure we are in an asyncio environment.
- assert asyncio.get_event_loop() is not None
- # This may be called from a different thread than the one the server is
- # running on. For this reason, we use an event to coordinate with the server
- # instead of calling shutdown()/startup() directly.
- self.restart_requested.set()
+ # This coroutine may be called from a different thread than the one the
+ # server is running on, and from an async environment that's not asyncio.
+ # For this reason, we use an event to coordinate with the server
+ # instead of calling shutdown()/startup() directly, and should not make
+ # any asyncio-specific operations.
self.started = False
+ self.restart_requested.set()
while not self.started:
- await asyncio.sleep(0.5)
+ await sleep(0.2)
async def watch_restarts(self):
while True:
await self.startup()
-@pytest.fixture
-def restart():
- """Restart the running server from an async test function.
-
- This fixture deals with possible differences between the environment of the
- test function and that of the server.
- """
- asyncio_backend = AsyncioBackend()
-
- async def restart(server):
- backend = lookup_backend()
- await backend.run_in_threadpool(asyncio_backend.run, server.restart)
-
- return restart
-
-
def serve_in_thread(server: Server):
thread = threading.Thread(target=server.run)
thread.start()
@pytest.mark.usefixtures("async_environment")
-async def test_keepalive_connection_closed_by_server_is_reestablished(server, restart):
+async def test_keepalive_connection_closed_by_server_is_reestablished(server):
"""
Upon keep-alive connection closed by remote a new connection
should be reestablished.
await response.aread()
# Shutdown the server to close the keep-alive connection
- await restart(server)
+ await server.restart()
response = await http.request("GET", server.url)
await response.aread()
@pytest.mark.usefixtures("async_environment")
-async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
- server, restart
-):
+async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server):
"""
Upon keep-alive connection closed by remote a new connection
should be reestablished.
await response.aread()
# Shutdown the server to close the keep-alive connection
- await restart(server)
+ await server.restart()
response = await http.request("GET", server.url)
await response.aread()
@pytest.mark.usefixtures("async_environment")
-async def test_connection_closed_free_semaphore_on_acquire(server, restart):
+async def test_connection_closed_free_semaphore_on_acquire(server):
"""
Verify that max_connections semaphore is released
properly on a disconnected connection.
await response.aread()
# Close the connection so we're forced to recycle it
- await restart(server)
+ await server.restart()
response = await http.request("GET", server.url)
assert response.status_code == 200
try:
assert stream.is_connection_dropped() is False
- assert get_cipher(backend, stream) is None
+ assert get_cipher(stream) is None
stream = await stream.start_tls(https_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
- assert get_cipher(backend, stream) is not None
+ assert get_cipher(stream) is not None
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
try:
assert stream.is_connection_dropped() is False
- assert get_cipher(backend, stream) is None
+ assert get_cipher(stream) is None
stream = await stream.start_tls(https_uds_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
- assert get_cipher(backend, stream) is not None
+ assert get_cipher(stream) is not None
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
try:
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
await run_concurrently(
- backend, lambda: stream.read(10, timeout), lambda: stream.read(10, timeout)
+ lambda: stream.read(10, timeout), lambda: stream.read(10, timeout)
)
finally:
await stream.close()