]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Typing: use wsgiref.types to validate types and fix issues uncovered (#2467)
authorAdrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Tue, 29 Nov 2022 17:55:21 +0000 (11:55 -0600)
committerGitHub <noreply@github.com>
Tue, 29 Nov 2022 17:55:21 +0000 (17:55 +0000)
* Typing: use wsgiref.types to validate types and fix issues uncovered

- start_response() must return a write(bytes) function, even though this
  is now deprecated. It's fine to be a no-op here.
- sys.exc_info() can return (None, None, None), so make sure to handle that case.

* remove typing_extensions

Co-authored-by: Martijn Pieters <mj@zopatista.com>
httpx/_transports/wsgi.py

index c7e3801a3448763829fc6c4bdb6710a01a7e480d..f27a77aea6c3f9aec8bac9344b32c4a413dc850b 100644 (file)
@@ -1,14 +1,31 @@
 import io
 import itertools
 import sys
+import types
 import typing
 
 from .._models import Request, Response
 from .._types import SyncByteStream
 from .base import BaseTransport
 
+_T = typing.TypeVar("_T")
+_ExcInfo = typing.Tuple[typing.Type[BaseException], BaseException, types.TracebackType]
+_OptExcInfo = typing.Union[_ExcInfo, typing.Tuple[None, None, None]]
 
-def _skip_leading_empty_chunks(body: typing.Iterable[bytes]) -> typing.Iterable[bytes]:
+
+# backported wsgiref.types definitions from Python 3.11
+StartResponse = typing.Callable[
+    [str, typing.List[typing.Tuple[str, str]], typing.Optional[_OptExcInfo]],
+    typing.Callable[[bytes], object],
+]
+
+
+WSGIApplication = typing.Callable[
+    [typing.Dict[str, typing.Any], StartResponse], typing.Iterable[bytes]
+]
+
+
+def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]:
     body = iter(body)
     for chunk in body:
         if chunk:
@@ -54,7 +71,7 @@ class WSGITransport(BaseTransport):
 
     Arguments:
 
-    * `app` - The ASGI application.
+    * `app` - The WSGI application.
     * `raise_app_exceptions` - Boolean indicating if exceptions in the application
        should be raised. Default to `True`. Can be set to `False` for use cases
        such as testing the content of a client 500 response.
@@ -65,7 +82,7 @@ class WSGITransport(BaseTransport):
 
     def __init__(
         self,
-        app: typing.Callable[..., typing.Any],
+        app: WSGIApplication,
         raise_app_exceptions: bool = True,
         script_name: str = "",
         remote_addr: str = "127.0.0.1",
@@ -111,12 +128,13 @@ class WSGITransport(BaseTransport):
         def start_response(
             status: str,
             response_headers: typing.List[typing.Tuple[str, str]],
-            exc_info: typing.Any = None,
-        ) -> None:
+            exc_info: typing.Optional[_OptExcInfo] = None,
+        ) -> typing.Callable[[bytes], typing.Any]:
             nonlocal seen_status, seen_response_headers, seen_exc_info
             seen_status = status
             seen_response_headers = response_headers
             seen_exc_info = exc_info
+            return lambda _: None
 
         result = self.app(environ, start_response)
 
@@ -124,7 +142,7 @@ class WSGITransport(BaseTransport):
 
         assert seen_status is not None
         assert seen_response_headers is not None
-        if seen_exc_info and self.raise_app_exceptions:
+        if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions:
             raise seen_exc_info[1]
 
         status_code = int(seen_status.split()[0])