]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Fix race condition on stream.read (#535)
authorFlorimond Manca <florimond.manca@gmail.com>
Fri, 22 Nov 2019 08:34:09 +0000 (09:34 +0100)
committerGitHub <noreply@github.com>
Fri, 22 Nov 2019 08:34:09 +0000 (09:34 +0100)
* Fix race condition on stream.read

* Refactor run_concurrently

httpx/concurrency/asyncio.py
httpx/concurrency/trio.py
tests/concurrency.py
tests/test_concurrency.py

index a0163ed02168ad05df947880cba9b30046ed5a11..d3febbb39dc3b897efba6885a3cf99fd2cc42929 100644 (file)
@@ -51,6 +51,7 @@ class SocketStream(BaseSocketStream):
         self.stream_reader = stream_reader
         self.stream_writer = stream_writer
         self.timeout = timeout
+        self.read_lock = asyncio.Lock()
 
         self._inner: typing.Optional[SocketStream] = None
 
@@ -144,8 +145,10 @@ class SocketStream(BaseSocketStream):
             should_raise = flag is None or flag.raise_on_read_timeout
             read_timeout = timeout.read_timeout if should_raise else 0.01
             try:
-                data = await asyncio.wait_for(self.stream_reader.read(n), read_timeout)
-                break
+                async with self.read_lock:
+                    data = await asyncio.wait_for(
+                        self.stream_reader.read(n), read_timeout
+                    )
             except asyncio.TimeoutError:
                 if should_raise:
                     raise ReadTimeout() from None
@@ -155,6 +158,8 @@ class SocketStream(BaseSocketStream):
                 # doesn't seem to allow on 3.6.
                 # See: https://github.com/encode/httpx/issues/382
                 await asyncio.sleep(0)
+            else:
+                break
 
         return data
 
index c84b1e4b8b2089c8c23dacaea614565d5bc9b193..f2ab919b8aa3775ef3192214eb7355ff316d313f 100644 (file)
@@ -32,6 +32,7 @@ class SocketStream(BaseSocketStream):
         self.stream = stream
         self.timeout = timeout
         self.write_buffer = b""
+        self.read_lock = trio.Lock()
         self.write_lock = trio.Lock()
 
     async def start_tls(
@@ -74,7 +75,8 @@ class SocketStream(BaseSocketStream):
             read_timeout = _or_inf(timeout.read_timeout if should_raise else 0.01)
 
             with trio.move_on_after(read_timeout):
-                return await self.stream.receive_some(max_bytes=n)
+                async with self.read_lock:
+                    return await self.stream.receive_some(max_bytes=n)
 
             if should_raise:
                 raise ReadTimeout() from None
index 99d5d3fca100613610cc3046a33e69075e84512a..0d7d1350831a8f4e153454c158a4a7cb2d7ab5c1 100644 (file)
@@ -5,8 +5,12 @@ required as part of the ConcurrencyBackend API.
 
 import asyncio
 import functools
+import typing
+
+import trio
 
 from httpx import AsyncioBackend
+from httpx.concurrency.trio import TrioBackend
 
 
 @functools.singledispatch
@@ -19,13 +23,24 @@ async def _sleep_asyncio(backend, seconds: int):
     await asyncio.sleep(seconds)
 
 
-try:
-    import trio
-    from httpx.concurrency.trio import TrioBackend
-except ImportError:  # pragma: no cover
-    pass
-else:
+@sleep.register(TrioBackend)
+async def _sleep_trio(backend, seconds: int):
+    await trio.sleep(seconds)
+
+
+@functools.singledispatch
+async def run_concurrently(backend, *coroutines: typing.Callable[[], typing.Awaitable]):
+    raise NotImplementedError  # pragma: no cover
+
+
+@run_concurrently.register(AsyncioBackend)
+async def _run_concurrently_asyncio(backend, *coroutines):
+    coros = (coroutine() for coroutine in coroutines)
+    await asyncio.gather(*coros)
+
 
-    @sleep.register(TrioBackend)
-    async def _sleep_trio(backend, seconds: int):
-        await trio.sleep(seconds)
+@run_concurrently.register(TrioBackend)
+async def _run_concurrently_trio(backend, *coroutines):
+    async with trio.open_nursery() as nursery:
+        for coroutine in coroutines:
+            nursery.start_soon(coroutine)
index cf3844cffedf33b7e69dd8f88e8b8f5755fb5894..878d6a8933b51e7919a7a154d0aeb87e5f02d16a 100644 (file)
@@ -3,6 +3,7 @@ import trio
 
 from httpx import AsyncioBackend, HTTPVersionConfig, SSLConfig, TimeoutConfig
 from httpx.concurrency.trio import TrioBackend
+from tests.concurrency import run_concurrently
 
 
 @pytest.mark.parametrize(
@@ -72,3 +73,19 @@ async def test_start_tls_on_socket_stream(
 
     finally:
         await stream.close()
+
+
+async def test_concurrent_read(server, backend):
+    """
+    Regression test for: https://github.com/encode/httpx/issues/527
+    """
+    stream = await backend.open_tcp_stream(
+        server.url.host, server.url.port, ssl_context=None, timeout=TimeoutConfig(5)
+    )
+    try:
+        await stream.write(b"GET / HTTP/1.1\r\n\r\n")
+        await run_concurrently(
+            backend, lambda: stream.read(10), lambda: stream.read(10)
+        )
+    finally:
+        await stream.close()