From: Florimond Manca Date: Fri, 22 Nov 2019 08:34:09 +0000 (+0100) Subject: Fix race condition on stream.read (#535) X-Git-Tag: 0.8.0~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a05ba2e9148c3f74c80c68a727b7e52b3d751c8c;p=thirdparty%2Fhttpx.git Fix race condition on stream.read (#535) * Fix race condition on stream.read * Refactor run_concurrently --- diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index a0163ed0..d3febbb3 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -51,6 +51,7 @@ class SocketStream(BaseSocketStream): self.stream_reader = stream_reader self.stream_writer = stream_writer self.timeout = timeout + self.read_lock = asyncio.Lock() self._inner: typing.Optional[SocketStream] = None @@ -144,8 +145,10 @@ class SocketStream(BaseSocketStream): should_raise = flag is None or flag.raise_on_read_timeout read_timeout = timeout.read_timeout if should_raise else 0.01 try: - data = await asyncio.wait_for(self.stream_reader.read(n), read_timeout) - break + async with self.read_lock: + data = await asyncio.wait_for( + self.stream_reader.read(n), read_timeout + ) except asyncio.TimeoutError: if should_raise: raise ReadTimeout() from None @@ -155,6 +158,8 @@ class SocketStream(BaseSocketStream): # doesn't seem to allow on 3.6. # See: https://github.com/encode/httpx/issues/382 await asyncio.sleep(0) + else: + break return data diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index c84b1e4b..f2ab919b 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -32,6 +32,7 @@ class SocketStream(BaseSocketStream): self.stream = stream self.timeout = timeout self.write_buffer = b"" + self.read_lock = trio.Lock() self.write_lock = trio.Lock() async def start_tls( @@ -74,7 +75,8 @@ class SocketStream(BaseSocketStream): read_timeout = _or_inf(timeout.read_timeout if should_raise else 0.01) with trio.move_on_after(read_timeout): - return await self.stream.receive_some(max_bytes=n) + async with self.read_lock: + return await self.stream.receive_some(max_bytes=n) if should_raise: raise ReadTimeout() from None diff --git a/tests/concurrency.py b/tests/concurrency.py index 99d5d3fc..0d7d1350 100644 --- a/tests/concurrency.py +++ b/tests/concurrency.py @@ -5,8 +5,12 @@ required as part of the ConcurrencyBackend API. import asyncio import functools +import typing + +import trio from httpx import AsyncioBackend +from httpx.concurrency.trio import TrioBackend @functools.singledispatch @@ -19,13 +23,24 @@ async def _sleep_asyncio(backend, seconds: int): await asyncio.sleep(seconds) -try: - import trio - from httpx.concurrency.trio import TrioBackend -except ImportError: # pragma: no cover - pass -else: +@sleep.register(TrioBackend) +async def _sleep_trio(backend, seconds: int): + await trio.sleep(seconds) + + +@functools.singledispatch +async def run_concurrently(backend, *coroutines: typing.Callable[[], typing.Awaitable]): + raise NotImplementedError # pragma: no cover + + +@run_concurrently.register(AsyncioBackend) +async def _run_concurrently_asyncio(backend, *coroutines): + coros = (coroutine() for coroutine in coroutines) + await asyncio.gather(*coros) + - @sleep.register(TrioBackend) - async def _sleep_trio(backend, seconds: int): - await trio.sleep(seconds) +@run_concurrently.register(TrioBackend) +async def _run_concurrently_trio(backend, *coroutines): + async with trio.open_nursery() as nursery: + for coroutine in coroutines: + nursery.start_soon(coroutine) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index cf3844cf..878d6a89 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -3,6 +3,7 @@ import trio from httpx import AsyncioBackend, HTTPVersionConfig, SSLConfig, TimeoutConfig from httpx.concurrency.trio import TrioBackend +from tests.concurrency import run_concurrently @pytest.mark.parametrize( @@ -72,3 +73,19 @@ async def test_start_tls_on_socket_stream( finally: await stream.close() + + +async def test_concurrent_read(server, backend): + """ + Regression test for: https://github.com/encode/httpx/issues/527 + """ + stream = await backend.open_tcp_stream( + server.url.host, server.url.port, ssl_context=None, timeout=TimeoutConfig(5) + ) + try: + await stream.write(b"GET / HTTP/1.1\r\n\r\n") + await run_concurrently( + backend, lambda: stream.read(10), lambda: stream.read(10) + ) + finally: + await stream.close()