]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add trio concurrency backend (#276)
authorFlorimond Manca <florimond.manca@gmail.com>
Sat, 21 Sep 2019 16:10:20 +0000 (18:10 +0200)
committerSeth Michael Larson <sethmichaellarson@gmail.com>
Sat, 21 Sep 2019 16:10:20 +0000 (11:10 -0500)
13 files changed:
httpx/client.py
httpx/concurrency/asyncio.py
httpx/concurrency/base.py
httpx/concurrency/trio.py [new file with mode: 0644]
httpx/dispatch/asgi.py
httpx/dispatch/http2.py
setup.cfg
setup.py
test-requirements.txt
tests/concurrency.py
tests/conftest.py
tests/dispatch/test_connection_pools.py
tests/test_timeouts.py

index 266d980c9bbe1ba13905929d39e1065ac75d54db..653e7d86cc28b0e0c8cea1c7a82a3dcb68c1445d 100644 (file)
@@ -81,7 +81,7 @@ class BaseClient:
             if param_count == 2:
                 dispatch = WSGIDispatch(app=app)
             else:
-                dispatch = ASGIDispatch(app=app)
+                dispatch = ASGIDispatch(app=app, backend=backend)
 
         self.trust_env = True if trust_env is None else trust_env
 
index 1a145bed9039e2ee1c149e9890cf232379bd9c91..696e32c148c263fcc83e3f0c9e202c5f54cc6c8a 100644 (file)
@@ -112,6 +112,20 @@ class TCPStream(BaseTCPStream):
                     raise WriteTimeout() from None
 
     def is_connection_dropped(self) -> bool:
+        # Counter-intuitively, what we really want to know here is whether the socket is
+        # *readable*, i.e. whether it would return immediately with empty bytes if we
+        # called `.recv()` on it, indicating that the other end has closed the socket.
+        # See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
+        #
+        # As it turns out, asyncio checks for readability in the background
+        # (see: https://github.com/encode/httpx/pull/276#discussion_r322000402),
+        # so checking for EOF or readability here would yield the same result.
+        #
+        # At the cost of rigour, we check for EOF instead of readability because asyncio
+        # does not expose any public API to check for readability.
+        # (For a solution that uses private asyncio APIs, see:
+        # https://github.com/encode/httpx/pull/143#issuecomment-515202982)
+
         return self.stream_reader.at_eof()
 
     async def close(self) -> None:
index 63fb14323d2177970b91f533e1d013a82d5a72ad..fc784b30f8f823bebdbcc4d1b7030d88ae59f891 100644 (file)
@@ -187,3 +187,10 @@ class BaseBackgroundManager:
         traceback: TracebackType = None,
     ) -> None:
         raise NotImplementedError()  # pragma: no cover
+
+    async def close(self, exception: BaseException = None) -> None:
+        if exception is None:
+            await self.__aexit__(None, None, None)
+        else:
+            traceback = exception.__traceback__  # type: ignore
+            await self.__aexit__(type(exception), exception, traceback)
diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py
new file mode 100644 (file)
index 0000000..14d7986
--- /dev/null
@@ -0,0 +1,255 @@
+import functools
+import math
+import ssl
+import typing
+from types import TracebackType
+
+import trio
+
+from ..config import PoolLimits, TimeoutConfig
+from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
+from .base import (
+    BaseBackgroundManager,
+    BaseEvent,
+    BasePoolSemaphore,
+    BaseQueue,
+    BaseTCPStream,
+    ConcurrencyBackend,
+    TimeoutFlag,
+)
+
+
+def _or_inf(value: typing.Optional[float]) -> float:
+    return value if value is not None else float("inf")
+
+
+class TCPStream(BaseTCPStream):
+    def __init__(
+        self,
+        stream: typing.Union[trio.SocketStream, trio.SSLStream],
+        timeout: TimeoutConfig,
+    ) -> None:
+        self.stream = stream
+        self.timeout = timeout
+        self.write_buffer = b""
+        self.write_lock = trio.Lock()
+
+    def get_http_version(self) -> str:
+        if not isinstance(self.stream, trio.SSLStream):
+            return "HTTP/1.1"
+
+        ident = self.stream.selected_alpn_protocol()
+        if ident is None:
+            return "HTTP/1.1"
+
+        return "HTTP/2" if ident == "h2" else "HTTP/1.1"
+
+    async def read(
+        self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
+    ) -> bytes:
+        if timeout is None:
+            timeout = self.timeout
+
+        while True:
+            # Check our flag at the first possible moment, and use a fine
+            # grained retry loop if we're not yet in read-timeout mode.
+            should_raise = flag is None or flag.raise_on_read_timeout
+            read_timeout = _or_inf(timeout.read_timeout if should_raise else 0.01)
+
+            with trio.move_on_after(read_timeout):
+                return await self.stream.receive_some(max_bytes=n)
+
+            if should_raise:
+                raise ReadTimeout() from None
+
+    def is_connection_dropped(self) -> bool:
+        # Adapted from: https://github.com/encode/httpx/pull/143#issuecomment-515202982
+        stream = self.stream
+
+        # Peek through any SSLStream wrappers to get the underlying SocketStream.
+        while hasattr(stream, "transport_stream"):
+            stream = stream.transport_stream
+        assert isinstance(stream, trio.SocketStream)
+
+        # Counter-intuitively, what we really want to know here is whether the socket is
+        # *readable*, i.e. whether it would return immediately with empty bytes if we
+        # called `.recv()` on it, indicating that the other end has closed the socket.
+        # See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
+        return stream.socket.is_readable()
+
+    def write_no_block(self, data: bytes) -> None:
+        self.write_buffer += data
+
+    async def write(
+        self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
+    ) -> None:
+        if self.write_buffer:
+            previous_data = self.write_buffer
+            # Reset before recursive call, otherwise we'll go through
+            # this branch indefinitely.
+            self.write_buffer = b""
+            try:
+                await self.write(previous_data, timeout=timeout, flag=flag)
+            except WriteTimeout:
+                self.writer_buffer = previous_data
+                raise
+
+        if not data:
+            return
+
+        if timeout is None:
+            timeout = self.timeout
+
+        write_timeout = _or_inf(timeout.write_timeout)
+
+        while True:
+            with trio.move_on_after(write_timeout):
+                async with self.write_lock:
+                    await self.stream.send_all(data)
+                break
+            # We check our flag at the first possible moment, in order to
+            # allow us to suppress write timeouts, if we've since
+            # switched over to read-timeout mode.
+            should_raise = flag is None or flag.raise_on_write_timeout
+            if should_raise:
+                raise WriteTimeout() from None
+
+    async def close(self) -> None:
+        await self.stream.aclose()
+
+
+class PoolSemaphore(BasePoolSemaphore):
+    def __init__(self, pool_limits: PoolLimits):
+        self.pool_limits = pool_limits
+
+    @property
+    def semaphore(self) -> typing.Optional[trio.Semaphore]:
+        if not hasattr(self, "_semaphore"):
+            max_connections = self.pool_limits.hard_limit
+            if max_connections is None:
+                self._semaphore = None
+            else:
+                self._semaphore = trio.Semaphore(
+                    max_connections, max_value=max_connections
+                )
+        return self._semaphore
+
+    async def acquire(self) -> None:
+        if self.semaphore is None:
+            return
+
+        timeout = _or_inf(self.pool_limits.pool_timeout)
+
+        with trio.move_on_after(timeout):
+            await self.semaphore.acquire()
+            return
+
+        raise PoolTimeout()
+
+    def release(self) -> None:
+        if self.semaphore is None:
+            return
+
+        self.semaphore.release()
+
+
+class TrioBackend(ConcurrencyBackend):
+    async def open_tcp_stream(
+        self,
+        hostname: str,
+        port: int,
+        ssl_context: typing.Optional[ssl.SSLContext],
+        timeout: TimeoutConfig,
+    ) -> TCPStream:
+        connect_timeout = _or_inf(timeout.connect_timeout)
+
+        with trio.move_on_after(connect_timeout) as cancel_scope:
+            stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
+            if ssl_context is not None:
+                stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
+                await stream.do_handshake()
+
+        if cancel_scope.cancelled_caught:
+            raise ConnectTimeout()
+
+        return TCPStream(stream=stream, timeout=timeout)
+
+    async def run_in_threadpool(
+        self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
+    ) -> typing.Any:
+        return await trio.to_thread.run_sync(
+            functools.partial(func, **kwargs) if kwargs else func, *args
+        )
+
+    def run(
+        self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
+    ) -> typing.Any:
+        return trio.run(
+            functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
+        )
+
+    def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
+        return PoolSemaphore(limits)
+
+    def create_queue(self, max_size: int) -> BaseQueue:
+        return Queue(max_size=max_size)
+
+    def create_event(self) -> BaseEvent:
+        return Event()
+
+    def background_manager(
+        self, coroutine: typing.Callable, *args: typing.Any
+    ) -> "BackgroundManager":
+        return BackgroundManager(coroutine, *args)
+
+
+class Queue(BaseQueue):
+    def __init__(self, max_size: int) -> None:
+        self.send_channel, self.receive_channel = trio.open_memory_channel(math.inf)
+
+    async def get(self) -> typing.Any:
+        return await self.receive_channel.receive()
+
+    async def put(self, value: typing.Any) -> None:
+        await self.send_channel.send(value)
+
+
+class Event(BaseEvent):
+    def __init__(self) -> None:
+        self._event = trio.Event()
+
+    def set(self) -> None:
+        self._event.set()
+
+    def is_set(self) -> bool:
+        return self._event.is_set()
+
+    async def wait(self) -> None:
+        await self._event.wait()
+
+    def clear(self) -> None:
+        # trio.Event.clear() was deprecated in Trio 0.12.
+        # https://github.com/python-trio/trio/issues/637
+        self._event = trio.Event()
+
+
+class BackgroundManager(BaseBackgroundManager):
+    def __init__(self, coroutine: typing.Callable, *args: typing.Any) -> None:
+        self.coroutine = coroutine
+        self.args = args
+        self.nursery_manager = trio.open_nursery()
+        self.nursery: typing.Optional[trio.Nursery] = None
+
+    async def __aenter__(self) -> "BackgroundManager":
+        self.nursery = await self.nursery_manager.__aenter__()
+        self.nursery.start_soon(self.coroutine, *self.args)
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: TracebackType = None,
+    ) -> None:
+        assert self.nursery is not None
+        await self.nursery_manager.__aexit__(exc_type, exc_value, traceback)
index c56d757c71b1926bbba57ace43816e99e8d66223..0633ea8f007f0ac74b82e0450157d372e4cdd352 100644 (file)
@@ -130,6 +130,7 @@ class ASGIDispatch(AsyncDispatcher):
         await response_started_or_failed.wait()
 
         if app_exc is not None and self.raise_app_exceptions:
+            await background.close(app_exc)
             raise app_exc
 
         assert status_code is not None, "application did not return a response."
@@ -138,7 +139,7 @@ class ASGIDispatch(AsyncDispatcher):
         async def on_close() -> None:
             nonlocal response_body
             await response_body.drain()
-            await background.__aexit__(None, None, None)
+            await background.close(app_exc)
             if app_exc is not None and self.raise_app_exceptions:
                 raise app_exc
 
index 4ddd29f6f7ade6c6696c5acd4d67dd6c7c79ecaf..b1efe6a223d321f7d534323196d544b511cdfdf6 100644 (file)
@@ -6,6 +6,7 @@ import h2.events
 
 from ..concurrency.base import BaseEvent, BaseTCPStream, ConcurrencyBackend, TimeoutFlag
 from ..config import TimeoutConfig, TimeoutTypes
+from ..exceptions import ProtocolError
 from ..models import AsyncRequest, AsyncResponse
 from ..utils import get_logger
 
@@ -187,6 +188,10 @@ class HTTP2Connection:
                 logger.debug(
                     f"receive_event stream_id={event_stream_id} event={event!r}"
                 )
+
+                if hasattr(event, "error_code"):
+                    raise ProtocolError(event)
+
                 if isinstance(event, h2.events.WindowUpdated):
                     if event_stream_id == 0:
                         for window_update_event in self.window_update_received.values():
index a9fc1a8fe773f08858af18ceacda6a3156ec5692..299bac64bffb75151beac46970833bd6ccbaf5d3 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -11,7 +11,7 @@ combine_as_imports = True
 force_grid_wrap = 0
 include_trailing_comma = True
 known_first_party = httpx,tests
-known_third_party = brotli,certifi,chardet,cryptography,h11,h2,hstspreload,nox,pytest,requests,rfc3986,setuptools,trustme,uvicorn
+known_third_party = brotli,certifi,chardet,cryptography,h11,h2,hstspreload,nox,pytest,requests,rfc3986,setuptools,trio,trustme,uvicorn
 line_length = 88
 multi_line_output = 3
 
index 8bd9d605f8571cd8c0d57e82566733a631e3ce1b..298d06f3165fdaef2382796a551ffa2988f75a64 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -59,6 +59,7 @@ setup(
         "idna==2.*",
         "rfc3986==1.*",
     ],
+    extras_require={"trio": ["trio"]},
     classifiers=[
         "Development Status :: 3 - Alpha",
         "Environment :: Web Environment",
index cbef3872c282c4a0b16729827ca4e780928a105d..870e55861a85948894dfb0d66d8daf145019336b 100644 (file)
@@ -1,4 +1,4 @@
--e .
+-e .[trio]
 
 # Optional
 brotlipy==0.7.*
@@ -11,6 +11,7 @@ isort
 mypy
 pytest
 pytest-asyncio
+pytest-trio
 pytest-cov
 trustme
 uvicorn
index 1240034d3443a8fa0704f72667fd6b62ac0b0411..99d5d3fca100613610cc3046a33e69075e84512a 100644 (file)
@@ -17,3 +17,15 @@ async def sleep(backend, seconds: int):
 @sleep.register(AsyncioBackend)
 async def _sleep_asyncio(backend, seconds: int):
     await asyncio.sleep(seconds)
+
+
+try:
+    import trio
+    from httpx.concurrency.trio import TrioBackend
+except ImportError:  # pragma: no cover
+    pass
+else:
+
+    @sleep.register(TrioBackend)
+    async def _sleep_trio(backend, seconds: int):
+        await trio.sleep(seconds)
index 658a6f943eec88a3d6175e54e2d6083f515fe7b1..ed9ea34b8f8a5c4da622ce9fa4f5ad0e21ead805 100644 (file)
@@ -47,7 +47,17 @@ def clean_environ() -> typing.Dict[str, typing.Any]:
     os.environ.update(original_environ)
 
 
-@pytest.fixture(params=[pytest.param(AsyncioBackend, marks=pytest.mark.asyncio)])
+backend_params = [pytest.param(AsyncioBackend, marks=pytest.mark.asyncio)]
+
+try:
+    from httpx.concurrency.trio import TrioBackend
+except ImportError:  # pragma: no cover
+    pass
+else:
+    backend_params.append(pytest.param(TrioBackend, marks=pytest.mark.trio))
+
+
+@pytest.fixture(params=backend_params)
 def backend(request):
     backend_cls = request.param
     return backend_cls()
index 276580875fa3a8517fa1d0793df507b9f87dd7fb..6f4361dd115310694e4824de84b74b4f65690b07 100644 (file)
@@ -168,7 +168,9 @@ async def test_connection_closed_free_semaphore_on_acquire(server, restart, back
     Verify that max_connections semaphore is released
     properly on a disconnected connection.
     """
-    async with httpx.ConnectionPool(pool_limits=httpx.PoolLimits(hard_limit=1)) as http:
+    async with httpx.ConnectionPool(
+        pool_limits=httpx.PoolLimits(hard_limit=1), backend=backend
+    ) as http:
         response = await http.request("GET", server.url)
         await response.read()
 
index 765c8f7d3a5cdcd1d746fa3426b3bed93d2736da..8dc1e0800f8a2dfcd84982496ce2e281e34496b3 100644 (file)
@@ -38,7 +38,7 @@ async def test_connect_timeout(server, backend):
 
 
 async def test_pool_timeout(server, backend):
-    pool_limits = PoolLimits(hard_limit=1, pool_timeout=1e-6)
+    pool_limits = PoolLimits(hard_limit=1, pool_timeout=1e-4)
 
     async with AsyncClient(pool_limits=pool_limits, backend=backend) as client:
         response = await client.get(server.url, stream=True)