]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Refactor ASGITransport.request() (#1021)
authorFlorimond Manca <florimond.manca@gmail.com>
Sat, 13 Jun 2020 17:59:09 +0000 (19:59 +0200)
committerGitHub <noreply@github.com>
Sat, 13 Jun 2020 17:59:09 +0000 (19:59 +0200)
httpx/_transports/asgi.py

index af03e24fee8bb39751f148dbe4c13dace177fa07..2e228209926b6ea53c24c0ac989a5eeea8b4b46b 100644 (file)
@@ -1,5 +1,4 @@
-import typing
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 
 import httpcore
 import sniffio
@@ -7,11 +6,11 @@ import sniffio
 from .._content_streams import ByteStream
 from .._utils import warn_deprecated
 
-if typing.TYPE_CHECKING:  # pragma: no cover
+if TYPE_CHECKING:  # pragma: no cover
     import asyncio
     import trio
 
-    Event = typing.Union[asyncio.Event, trio.Event]
+    Event = Union[asyncio.Event, trio.Event]
 
 
 def create_event() -> "Event":
@@ -78,6 +77,10 @@ class ASGITransport(httpcore.AsyncHTTPTransport):
         stream: httpcore.AsyncByteStream = None,
         timeout: Dict[str, Optional[float]] = None,
     ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
+        headers = [] if headers is None else headers
+        stream = ByteStream(b"") if stream is None else stream
+
+        # ASGI scope.
         scheme, host, port, full_path = url
         path, _, query = full_path.partition(b"?")
         scope = {
@@ -93,20 +96,22 @@ class ASGITransport(httpcore.AsyncHTTPTransport):
             "client": self.client,
             "root_path": self.root_path,
         }
+
+        # Request.
+        request_body_chunks = stream.__aiter__()
+        request_complete = False
+
+        # Response.
         status_code = None
         response_headers = None
         body_parts = []
-        request_complete = False
         response_started = False
         response_complete = create_event()
 
-        headers = [] if headers is None else headers
-        stream = ByteStream(b"") if stream is None else stream
-
-        request_body_chunks = stream.__aiter__()
+        # ASGI callables.
 
         async def receive() -> dict:
-            nonlocal request_complete, response_complete
+            nonlocal request_complete
 
             if request_complete:
                 await response_complete.wait()
@@ -120,8 +125,7 @@ class ASGITransport(httpcore.AsyncHTTPTransport):
             return {"type": "http.request", "body": body, "more_body": True}
 
         async def send(message: dict) -> None:
-            nonlocal status_code, response_headers, body_parts
-            nonlocal response_started, response_complete
+            nonlocal status_code, response_headers, response_started
 
             if message["type"] == "http.response.start":
                 assert not response_started
@@ -144,7 +148,7 @@ class ASGITransport(httpcore.AsyncHTTPTransport):
         try:
             await self.app(scope, receive, send)
         except Exception:
-            if self.raise_app_exceptions or not response_complete:
+            if self.raise_app_exceptions or not response_complete.is_set():
                 raise
 
         assert response_complete.is_set()