]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Typing: always fill in generic type parameters (#2468)
authorAdrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Tue, 29 Nov 2022 16:36:03 +0000 (10:36 -0600)
committerGitHub <noreply@github.com>
Tue, 29 Nov 2022 16:36:03 +0000 (16:36 +0000)
* Typing: always fill in generic type parameters

Being explicit about the parameters helps find bugs and makes the library
easier to use for users.

- Tell mypy to disallow generics without parameter values
- Give all generic types parameters values

* fix things that aren't coming in from other commits

* lint

Co-authored-by: Martijn Pieters <mj@zopatista.com>
Co-authored-by: Tom Christie <tom@tomchristie.com>
httpx/_auth.py
httpx/_main.py
httpx/_models.py
httpx/_transports/asgi.py
httpx/_transports/mock.py
httpx/_transports/wsgi.py
httpx/_urls.py
httpx/_utils.py
setup.cfg
tests/client/test_auth.py

index 2b00b49d17d459983c187a11a857121c12ad424d..0f54be9b407a5e847916f4c8dfe0bb809bc102be 100644 (file)
@@ -10,6 +10,9 @@ from ._exceptions import ProtocolError
 from ._models import Request, Response
 from ._utils import to_bytes, to_str, unquote
 
+if typing.TYPE_CHECKING:  # pragma: no cover
+    from hashlib import _Hash
+
 
 class Auth:
     """
@@ -139,7 +142,7 @@ class BasicAuth(Auth):
 
 
 class DigestAuth(Auth):
-    _ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
+    _ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable[[bytes], "_Hash"]] = {
         "MD5": hashlib.md5,
         "MD5-SESS": hashlib.md5,
         "SHA": hashlib.sha1,
index a2d5a2f52f580c5fc9bba22663016c9930334a6c..ba79c69e45295c631405cb9c8e521209c8f95a9a 100644 (file)
@@ -179,7 +179,12 @@ def print_response(response: Response) -> None:
         console.print(f"<{len(response.content)} bytes of binary data>")
 
 
-def format_certificate(cert: dict) -> str:  # pragma: no cover
+_PCTRTT = typing.Tuple[typing.Tuple[str, str], ...]
+_PCTRTTT = typing.Tuple[_PCTRTT, ...]
+_PeerCertRetDictType = typing.Dict[str, typing.Union[str, _PCTRTTT, _PCTRTT]]
+
+
+def format_certificate(cert: _PeerCertRetDictType) -> str:  # pragma: no cover
     lines = []
     for key, value in cert.items():
         if isinstance(value, (list, tuple)):
index fcdcf865558deff53dbd71a18c6ca4873e5af6df..e3370369a26c8a85f643aa80fc0ae7d0600d6955 100644 (file)
@@ -3,7 +3,7 @@ import email.message
 import json as jsonlib
 import typing
 import urllib.request
-from collections.abc import Mapping, MutableMapping
+from collections.abc import Mapping
 from http.cookiejar import Cookie, CookieJar
 
 from ._content import ByteStream, UnattachedStream, encode_request, encode_response
@@ -1002,7 +1002,7 @@ class Response:
                 await self.stream.aclose()
 
 
-class Cookies(MutableMapping):
+class Cookies(typing.MutableMapping[str, str]):
     """
     HTTP Cookies, as a mutable mapping.
     """
index 711a6f6ce75738685e739d76315f0ce3818f169a..bdf7f7a145f3e6e2dca71de2dfac359bc34090bc 100644 (file)
@@ -14,6 +14,16 @@ if typing.TYPE_CHECKING:  # pragma: no cover
     Event = typing.Union[asyncio.Event, trio.Event]
 
 
+_Message = typing.Dict[str, typing.Any]
+_Receive = typing.Callable[[], typing.Awaitable[_Message]]
+_Send = typing.Callable[
+    [typing.Dict[str, typing.Any]], typing.Coroutine[None, None, None]
+]
+_ASGIApp = typing.Callable[
+    [typing.Dict[str, typing.Any], _Receive, _Send], typing.Coroutine[None, None, None]
+]
+
+
 def create_event() -> "Event":
     if sniffio.current_async_library() == "trio":
         import trio
@@ -68,7 +78,7 @@ class ASGITransport(AsyncBaseTransport):
 
     def __init__(
         self,
-        app: typing.Callable,
+        app: _ASGIApp,
         raise_app_exceptions: bool = True,
         root_path: str = "",
         client: typing.Tuple[str, int] = ("127.0.0.1", 123),
@@ -113,7 +123,7 @@ class ASGITransport(AsyncBaseTransport):
 
         # ASGI callables.
 
-        async def receive() -> dict:
+        async def receive() -> typing.Dict[str, typing.Any]:
             nonlocal request_complete
 
             if request_complete:
@@ -127,7 +137,7 @@ class ASGITransport(AsyncBaseTransport):
                 return {"type": "http.request", "body": b"", "more_body": False}
             return {"type": "http.request", "body": body, "more_body": True}
 
-        async def send(message: dict) -> None:
+        async def send(message: typing.Dict[str, typing.Any]) -> None:
             nonlocal status_code, response_headers, response_started
 
             if message["type"] == "http.response.start":
index f61aee710114cb0b3740abd7c55eba2dfc5d2a92..8a70dfe142634cc1659db176540a72caa3c85346 100644 (file)
@@ -6,7 +6,7 @@ from .base import AsyncBaseTransport, BaseTransport
 
 
 class MockTransport(AsyncBaseTransport, BaseTransport):
-    def __init__(self, handler: typing.Callable) -> None:
+    def __init__(self, handler: typing.Callable[[Request], Response]) -> None:
         self.handler = handler
 
     def handle_request(
@@ -29,6 +29,6 @@ class MockTransport(AsyncBaseTransport, BaseTransport):
 
         # https://simonwillison.net/2020/Sep/2/await-me-maybe/
         if asyncio.iscoroutine(response):
-            response = await response
+            response = await response  # type: ignore[func-returns-value,assignment]
 
         return response
index 3dedf49f96af8f4176f224a9c7322a0ec2008dae..c7e3801a3448763829fc6c4bdb6710a01a7e480d 100644 (file)
@@ -8,7 +8,7 @@ from .._types import SyncByteStream
 from .base import BaseTransport
 
 
-def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable:
+def _skip_leading_empty_chunks(body: typing.Iterable[bytes]) -> typing.Iterable[bytes]:
     body = iter(body)
     for chunk in body:
         if chunk:
@@ -65,7 +65,7 @@ class WSGITransport(BaseTransport):
 
     def __init__(
         self,
-        app: typing.Callable,
+        app: typing.Callable[..., typing.Any],
         raise_app_exceptions: bool = True,
         script_name: str = "",
         remote_addr: str = "127.0.0.1",
@@ -109,7 +109,9 @@ class WSGITransport(BaseTransport):
         seen_exc_info = None
 
         def start_response(
-            status: str, response_headers: list, exc_info: typing.Any = None
+            status: str,
+            response_headers: typing.List[typing.Tuple[str, str]],
+            exc_info: typing.Any = None,
         ) -> None:
             nonlocal seen_status, seen_response_headers, seen_exc_info
             seen_status = status
index 1211bbba9a33d202e2a7268256cf6c759a6b40e8..05db1652521a466b73f77e2588bb8cd03261bd6e 100644 (file)
@@ -570,7 +570,7 @@ class QueryParams(typing.Mapping[str, str]):
                 for k, v in dict_value.items()
             }
 
-    def keys(self) -> typing.KeysView:
+    def keys(self) -> typing.KeysView[str]:
         """
         Return all the keys in the query params.
 
@@ -581,7 +581,7 @@ class QueryParams(typing.Mapping[str, str]):
         """
         return self._dict.keys()
 
-    def values(self) -> typing.ValuesView:
+    def values(self) -> typing.ValuesView[str]:
         """
         Return all the values in the query params. If a key occurs more than once
         only the first item for that key is returned.
@@ -593,7 +593,7 @@ class QueryParams(typing.Mapping[str, str]):
         """
         return {k: v[0] for k, v in self._dict.items()}.values()
 
-    def items(self) -> typing.ItemsView:
+    def items(self) -> typing.ItemsView[str, str]:
         """
         Return all items in the query params. If a key occurs more than once
         only the first item for that key is returned.
index b2c3cbd4f4784266ed51e2bed5a4cf9fb13c753d..01eaacedb3678571a08ea76a13f06eea24132cf2 100644 (file)
@@ -508,7 +508,7 @@ class URLPattern:
         return True
 
     @property
-    def priority(self) -> tuple:
+    def priority(self) -> typing.Tuple[int, int, int]:
         """
         The priority allows URLPattern instances to be sortable, so that
         we can match from most specific to least specific.
index 3b085d087c017dd73de72ed1ed4163a665389211..671b0812456f68db3b65069c4655ab700994c002 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -4,6 +4,7 @@ max-line-length = 120
 
 [mypy]
 disallow_untyped_defs = True
+disallow_any_generics = True
 ignore_missing_imports = True
 no_implicit_optional = True
 show_error_codes = True
index bbb5ad9dc4d0ed832882766046d198abe0bc7499..735205c3aa33b4b6b7eb253645cc75d65772a6df 100644 (file)
@@ -428,7 +428,7 @@ async def test_digest_auth(
     assert response.status_code == 200
     assert len(response.history) == 1
 
-    authorization = typing.cast(dict, response.json())["auth"]
+    authorization = typing.cast(typing.Dict[str, typing.Any], response.json())["auth"]
     scheme, _, fields = authorization.partition(" ")
     assert scheme == "Digest"
 
@@ -459,7 +459,7 @@ async def test_digest_auth_no_specified_qop() -> None:
     assert response.status_code == 200
     assert len(response.history) == 1
 
-    authorization = typing.cast(dict, response.json())["auth"]
+    authorization = typing.cast(typing.Dict[str, typing.Any], response.json())["auth"]
     scheme, _, fields = authorization.partition(" ")
     assert scheme == "Digest"