]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
ASGI: Wait for response to complete before sending disconnect message (#919)
authorJamie Hewland <jhewland@gmail.com>
Tue, 12 May 2020 09:06:53 +0000 (11:06 +0200)
committerGitHub <noreply@github.com>
Tue, 12 May 2020 09:06:53 +0000 (10:06 +0100)
* asgi: Wait for response to complete before sending disconnect message

* Dial back type checking + remove concurrency module

* Remove somewhat redundant comment

httpx/_dispatch/asgi.py
requirements.txt
setup.py
tests/test_asgi.py

index 5edca1ed8c5a5e69766eb034506168d239b68e12..a86969bccacb2fbe13f570ab735379061c668c87 100644 (file)
@@ -1,9 +1,28 @@
+import typing
 from typing import Callable, Dict, List, Optional, Tuple
 
 import httpcore
+import sniffio
 
 from .._content_streams import ByteStream
 
+if typing.TYPE_CHECKING:  # pragma: no cover
+    import asyncio
+    import trio
+
+    Event = typing.Union[asyncio.Event, trio.Event]
+
+
+def create_event() -> "Event":
+    if sniffio.current_async_library() == "trio":
+        import trio
+
+        return trio.Event()
+    else:
+        import asyncio
+
+        return asyncio.Event()
+
 
 class ASGIDispatch(httpcore.AsyncHTTPTransport):
     """
@@ -76,8 +95,9 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
         status_code = None
         response_headers = None
         body_parts = []
+        request_complete = False
         response_started = False
-        response_complete = False
+        response_complete = create_event()
 
         headers = [] if headers is None else headers
         stream = ByteStream(b"") if stream is None else stream
@@ -85,14 +105,16 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
         request_body_chunks = stream.__aiter__()
 
         async def receive() -> dict:
-            nonlocal response_complete
+            nonlocal request_complete, response_complete
 
-            if response_complete:
+            if request_complete:
+                await response_complete.wait()
                 return {"type": "http.disconnect"}
 
             try:
                 body = await request_body_chunks.__anext__()
             except StopAsyncIteration:
+                request_complete = True
                 return {"type": "http.request", "body": b"", "more_body": False}
             return {"type": "http.request", "body": body, "more_body": True}
 
@@ -108,7 +130,7 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
                 response_started = True
 
             elif message["type"] == "http.response.body":
-                assert not response_complete
+                assert not response_complete.is_set()
                 body = message.get("body", b"")
                 more_body = message.get("more_body", False)
 
@@ -116,7 +138,7 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
                     body_parts.append(body)
 
                 if not more_body:
-                    response_complete = True
+                    response_complete.set()
 
         try:
             await self.app(scope, receive, send)
@@ -124,7 +146,7 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
             if self.raise_app_exceptions or not response_complete:
                 raise
 
-        assert response_complete
+        assert response_complete.is_set()
         assert status_code is not None
         assert response_headers is not None
 
index e5ac1a2ed9f0c8c08a22f3d59ffb4908b1bdb736..dd2409067dff8764bb744b6362ce1b5ea9367da8 100644 (file)
@@ -26,6 +26,7 @@ pytest-asyncio
 pytest-trio
 pytest-cov
 trio
+trio-typing
 trustme
 uvicorn
 seed-isort-config
index 554fed604c4fbee4b9ed88235fa0277341c272a1..6a5b137eda5b8eb709cc2dfdd7d7dd8782bc2bd1 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -61,6 +61,7 @@ setup(
         "idna==2.*",
         "rfc3986>=1.3,<2",
         "httpcore>=0.8.3",
+        "sniffio",
     ],
     classifiers=[
         "Development Status :: 4 - Beta",
index 72a003936e1aa1dbd6fd96c67d38aa2d40ceab06..d225baf41178c2b4ae6515f26cef1a8bf73bc918 100644 (file)
@@ -69,8 +69,7 @@ async def test_asgi_exc_after_response():
         await client.get("http://www.example.org/")
 
 
-@pytest.mark.asyncio
-async def test_asgi_disconnect_after_response_complete():
+async def test_asgi_disconnect_after_response_complete(async_environment):
     disconnect = False
 
     async def read_body(scope, receive, send):