]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Type check tests (#1054)
authorFlorimond Manca <florimond.manca@gmail.com>
Tue, 7 Jul 2020 09:10:43 +0000 (11:10 +0200)
committerGitHub <noreply@github.com>
Tue, 7 Jul 2020 09:10:43 +0000 (11:10 +0200)
17 files changed:
httpx/_types.py
scripts/check
tests/client/test_async_client.py
tests/client/test_auth.py
tests/client/test_client.py
tests/client/test_cookies.py
tests/client/test_headers.py
tests/client/test_properties.py
tests/client/test_proxies.py
tests/client/test_queryparams.py
tests/client/test_redirects.py
tests/conftest.py
tests/test_api.py
tests/test_config.py
tests/test_content_streams.py
tests/test_multipart.py
tests/test_status_codes.py

index a74020a4aefea59a4579d1f8bc4c6820fd7cb0d5..d2fc098e249f1104fc5250c78f710fc01d650a31 100644 (file)
@@ -72,4 +72,4 @@ FileTypes = Union[
     # (filename, file (or text), content_type)
     Tuple[Optional[str], FileContent, Optional[str]],
 ]
-RequestFiles = Union[Mapping[str, FileTypes], List[Tuple[str, FileTypes]]]
+RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
index 2b42506f6f7640b28d92ec10c2ac84faa4a794f1..f9fc19343bc8ffe7f0b5147418442fe8d03e78cc 100755 (executable)
@@ -10,5 +10,5 @@ set -x
 
 ${PREFIX}black --check --diff --target-version=py36 $SOURCE_FILES
 ${PREFIX}flake8 $SOURCE_FILES
-${PREFIX}mypy httpx
+${PREFIX}mypy $SOURCE_FILES
 ${PREFIX}isort --check --diff --project=httpx $SOURCE_FILES
index 649e428f5e4a3dc8714bc003ad46c172f57125df..6818b4a444b3b1c76ac1c0c5419237fca9edced9 100644 (file)
@@ -174,8 +174,11 @@ def test_dispatch_deprecated():
 
 
 def test_asgi_dispatch_deprecated():
+    async def app(scope, receive, send):
+        pass
+
     with pytest.warns(DeprecationWarning) as record:
-        ASGIDispatch(None)
+        ASGIDispatch(app)
 
     assert len(record) == 1
     assert (
index 818e65904a606c22f1e3af4490874b91d90c33e1..13184a06bac75d4d3cfd9648e2ac80395aca5a4c 100644 (file)
@@ -11,7 +11,6 @@ from httpx import (
     Auth,
     Client,
     DigestAuth,
-    Headers,
     ProtocolError,
     Request,
     RequestBodyUnavailable,
@@ -86,23 +85,26 @@ class MockDigestAuthTransport(httpcore.AsyncHTTPTransport):
     async def request(
         self,
         method: bytes,
-        url: typing.Tuple[bytes, bytes, int, bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: ContentStream,
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
+        headers: typing.List[typing.Tuple[bytes, bytes]] = None,
+        stream: httpcore.AsyncByteStream = None,
         timeout: typing.Dict[str, typing.Optional[float]] = None,
     ) -> typing.Tuple[
         bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
     ]:
         if self._response_count < self.send_response_after_attempt:
-            return self.challenge_send(method, url, headers, stream)
+            assert headers is not None
+            return self.challenge_send(method, headers)
 
         authorization = get_header_value(headers, "Authorization")
         body = JSONStream({"auth": authorization})
         return b"HTTP/1.1", 200, b"", [], body
 
     def challenge_send(
-        self, method: bytes, url: URL, headers: Headers, stream: ContentStream,
-    ) -> typing.Tuple[int, bytes, Headers, ContentStream]:
+        self, method: bytes, headers: typing.List[typing.Tuple[bytes, bytes]],
+    ) -> typing.Tuple[
+        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
+    ]:
         self._response_count += 1
         nonce = (
             hashlib.sha256(os.urandom(8)).hexdigest()
@@ -297,7 +299,8 @@ async def test_auth_hidden_header() -> None:
 async def test_auth_invalid_type() -> None:
     url = "https://example.org/"
     client = AsyncClient(
-        transport=AsyncMockTransport(), auth="not a tuple, not a callable",
+        transport=AsyncMockTransport(),
+        auth="not a tuple, not a callable",  # type: ignore
     )
     with pytest.raises(TypeError):
         await client.get(url)
index 02f78f6999c2fe058272ac5cbc47f43c9f450546..1426fc216c4b395be6de41a92ae25145a9cd3d90 100644 (file)
@@ -182,8 +182,11 @@ def test_dispatch_deprecated():
 
 
 def test_wsgi_dispatch_deprecated():
+    def app(start_response, environ):
+        pass
+
     with pytest.warns(DeprecationWarning) as record:
-        WSGIDispatch(None)
+        WSGIDispatch(app)
 
     assert len(record) == 1
     assert (
index a109ccc6eecb1ea59d1cb4fe6f554422d2c17aa6..68b6c64cf50a691f8da6a15a5e2cccacb1ceb107 100644 (file)
@@ -20,14 +20,15 @@ class MockTransport(httpcore.AsyncHTTPTransport):
     async def request(
         self,
         method: bytes,
-        url: typing.Tuple[bytes, bytes, int, bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: ContentStream,
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
+        headers: typing.List[typing.Tuple[bytes, bytes]] = None,
+        stream: httpcore.AsyncByteStream = None,
         timeout: typing.Dict[str, typing.Optional[float]] = None,
     ) -> typing.Tuple[
         bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
     ]:
         host, scheme, port, path = url
+        body: ContentStream
         if path.startswith(b"/echo_cookies"):
             cookie = get_header_value(headers, "cookie")
             body = JSONStream({"cookies": cookie})
index 6f26b1c7b34a10b07d81223db24b22bc1113d61c..2f87c38a1f2cbc94a5a5e8826d647852aa2679a9 100755 (executable)
@@ -13,13 +13,14 @@ class MockTransport(httpcore.AsyncHTTPTransport):
     async def request(
         self,
         method: bytes,
-        url: typing.Tuple[bytes, bytes, int, bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: ContentStream,
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
+        headers: typing.List[typing.Tuple[bytes, bytes]] = None,
+        stream: httpcore.AsyncByteStream = None,
         timeout: typing.Dict[str, typing.Optional[float]] = None,
     ) -> typing.Tuple[
         bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
     ]:
+        assert headers is not None
         headers_dict = {
             key.decode("ascii"): value.decode("ascii") for key, value in headers
         }
index 5dbbb4690ad6ee3ffc23064975947e260af718be..011c593cd3a5d3e2690bb054bab5df07cabbb590 100644 (file)
@@ -3,14 +3,14 @@ from httpx import AsyncClient, Cookies, Headers
 
 def test_client_headers():
     client = AsyncClient()
-    client.headers = {"a": "b"}
+    client.headers = {"a": "b"}  # type: ignore
     assert isinstance(client.headers, Headers)
     assert client.headers["A"] == "b"
 
 
 def test_client_cookies():
     client = AsyncClient()
-    client.cookies = {"a": "b"}
+    client.cookies = {"a": "b"}  # type: ignore
     assert isinstance(client.cookies, Cookies)
     mycookies = list(client.cookies.jar)
     assert len(mycookies) == 1
index 5222b08e342e59ce8588b7e9fa34641220986692..fb21760bf7ed5998761902d1dc813b4b9d2da978 100644 (file)
@@ -1,3 +1,4 @@
+import httpcore
 import pytest
 
 import httpx
@@ -24,7 +25,9 @@ def test_proxies_parameter(proxies, expected_proxies):
 
     for proxy_key, url in expected_proxies:
         assert proxy_key in client.proxies
-        assert client.proxies[proxy_key].proxy_origin == httpx.URL(url).raw[:3]
+        proxy = client.proxies[proxy_key]
+        assert isinstance(proxy, httpcore.AsyncHTTPProxy)
+        assert proxy.proxy_origin == httpx.URL(url).raw[:3]
 
     assert len(expected_proxies) == len(client.proxies)
 
@@ -81,6 +84,7 @@ def test_transport_for_request(url, proxies, expected):
     if expected is None:
         assert transport is client.transport
     else:
+        assert isinstance(transport, httpcore.AsyncHTTPProxy)
         assert transport.proxy_origin == httpx.URL(expected).raw[:3]
 
 
index 97e119962069b88f85024d6f04f7c7746e5af803..10a03539e2f64413d6d04ae634435396d83d0131 100644 (file)
@@ -3,7 +3,7 @@ import typing
 import httpcore
 import pytest
 
-from httpx import URL, AsyncClient, Headers, QueryParams
+from httpx import URL, AsyncClient, QueryParams
 from httpx._content_streams import ContentStream, JSONStream
 
 
@@ -11,16 +11,15 @@ class MockTransport(httpcore.AsyncHTTPTransport):
     async def request(
         self,
         method: bytes,
-        url: typing.Tuple[bytes, bytes, int, bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: ContentStream,
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
+        headers: typing.List[typing.Tuple[bytes, bytes]] = None,
+        stream: httpcore.AsyncByteStream = None,
         timeout: typing.Dict[str, typing.Optional[float]] = None,
     ) -> typing.Tuple[
         bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
     ]:
-        headers = Headers()
         body = JSONStream({"ok": "ok"})
-        return b"HTTP/1.1", 200, b"OK", headers, body
+        return b"HTTP/1.1", 200, b"OK", [], body
 
 
 def test_client_queryparams():
@@ -35,7 +34,7 @@ def test_client_queryparams_string():
     assert client.params["a"] == "b"
 
     client = AsyncClient()
-    client.params = "a=b"
+    client.params = "a=b"  # type: ignore
     assert isinstance(client.params, QueryParams)
     assert client.params["a"] == "b"
 
index ae800fa79299bf8a38f5225c9fdeb19e10bde9cb..30b6f6a128c37352c0fb628c8c786ac7b7ddc97c 100644 (file)
@@ -103,8 +103,8 @@ class MockTransport:
             headers_dict = {
                 key.decode("ascii"): value.decode("ascii") for key, value in headers
             }
-            content = ByteStream(json.dumps({"headers": headers_dict}).encode())
-            return b"HTTP/1.1", 200, b"OK", [], content
+            stream = ByteStream(json.dumps({"headers": headers_dict}).encode())
+            return b"HTTP/1.1", 200, b"OK", [], stream
 
         elif path == b"/redirect_body":
             code = codes.PERMANENT_REDIRECT
@@ -121,10 +121,10 @@ class MockTransport:
             headers_dict = {
                 key.decode("ascii"): value.decode("ascii") for key, value in headers
             }
-            body = ByteStream(
+            stream = ByteStream(
                 json.dumps({"body": content.decode(), "headers": headers_dict}).encode()
             )
-            return b"HTTP/1.1", 200, b"OK", [], body
+            return b"HTTP/1.1", 200, b"OK", [], stream
 
         elif path == b"/cross_subdomain":
             host = get_header_value(headers, "host")
@@ -402,9 +402,9 @@ class MockCookieTransport(httpcore.AsyncHTTPTransport):
     async def request(
         self,
         method: bytes,
-        url: typing.Tuple[bytes, bytes, int, bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: ContentStream,
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
+        headers: typing.List[typing.Tuple[bytes, bytes]] = None,
+        stream: httpcore.AsyncByteStream = None,
         timeout: typing.Dict[str, typing.Optional[float]] = None,
     ) -> typing.Tuple[
         bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
@@ -432,7 +432,8 @@ class MockCookieTransport(httpcore.AsyncHTTPTransport):
             ]
             return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
 
-        elif path == b"/logout":
+        else:
+            assert path == b"/logout"
             status_code = codes.SEE_OTHER
             headers = [
                 (b"location", b"/"),
index a145ce0fa08abb33256f0ece9a7769556e5d75f3..10576ebd8af6f7e84daa880f2e6f71739d7db894 100644 (file)
@@ -56,7 +56,7 @@ def async_environment(request: typing.Any) -> str:
 
 
 @pytest.fixture(scope="function", autouse=True)
-def clean_environ() -> typing.Dict[str, typing.Any]:
+def clean_environ():
     """Keeps os.environ clean for every test without having to mock os.environ"""
     original_environ = os.environ.copy()
     os.environ.clear()
index 4c1d61162032c4933216e2d1a7e9f1bb242e4c50..2d51d99e8a4860fc2bde85f5fe6fc7b7d4addd19 100644 (file)
@@ -68,7 +68,6 @@ def test_stream(server):
     assert response.http_version == "HTTP/1.1"
 
 
-@pytest.mark.asyncio
-async def test_get_invalid_url(server):
+def test_get_invalid_url():
     with pytest.raises(httpx.InvalidURL):
-        await httpx.get("invalid://example.org")
+        httpx.get("invalid://example.org")
index 41d81916ad5ffe9c870ad8944cf56de33ae57022..46d154cdb82810746f6d3bc2198f61f20f65e02c 100644 (file)
@@ -56,7 +56,7 @@ def test_load_ssl_config_verify_env_file(https_server, ca_cert_pem_file, config)
 
 def test_load_ssl_config_verify_directory():
     path = Path(certifi.where()).parent
-    ssl_config = SSLConfig(verify=path)
+    ssl_config = SSLConfig(verify=str(path))
     context = ssl_config.ssl_context
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
@@ -192,7 +192,7 @@ def test_ssl_config_support_for_keylog_file(tmpdir, monkeypatch):  # pragma: noc
 
         ssl_config = SSLConfig(trust_env=True)
 
-        assert ssl_config.ssl_context.keylog_filename is None
+        assert ssl_config.ssl_context.keylog_filename is None  # type: ignore
 
     filename = str(tmpdir.join("test.log"))
 
@@ -201,11 +201,11 @@ def test_ssl_config_support_for_keylog_file(tmpdir, monkeypatch):  # pragma: noc
 
         ssl_config = SSLConfig(trust_env=True)
 
-        assert ssl_config.ssl_context.keylog_filename == filename
+        assert ssl_config.ssl_context.keylog_filename == filename  # type: ignore
 
         ssl_config = SSLConfig(trust_env=False)
 
-        assert ssl_config.ssl_context.keylog_filename is None
+        assert ssl_config.ssl_context.keylog_filename is None  # type: ignore
 
 
 @pytest.mark.parametrize(
index 2b2adc92ae65ff245048cbff53362115191ea449..140aa8d2af4f4a36b786f03c0b7a12bf777f8705 100644 (file)
@@ -203,7 +203,7 @@ async def test_empty_request():
 
 def test_invalid_argument():
     with pytest.raises(TypeError):
-        encode(123)
+        encode(123)  # type: ignore
 
 
 @pytest.mark.asyncio
index fbabc7c483001e3fbdc8d0ecb1c854ece4b899d4..7d6f8e025d4c2d32cd985d41e53952a888c073fa 100644 (file)
@@ -8,7 +8,7 @@ import httpcore
 import pytest
 
 import httpx
-from httpx._content_streams import AsyncIteratorStream, encode
+from httpx._content_streams import AsyncIteratorStream, MultipartStream, encode
 from httpx._utils import format_form_param
 
 
@@ -16,7 +16,7 @@ class MockTransport(httpcore.AsyncHTTPTransport):
     async def request(
         self,
         method: bytes,
-        url: typing.Tuple[bytes, bytes, int, bytes],
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
         headers: typing.List[typing.Tuple[bytes, bytes]] = None,
         stream: httpcore.AsyncByteStream = None,
         timeout: typing.Dict[str, typing.Optional[float]] = None,
@@ -27,6 +27,7 @@ class MockTransport(httpcore.AsyncHTTPTransport):
         typing.List[typing.Tuple[bytes, bytes]],
         httpcore.AsyncByteStream,
     ]:
+        assert stream is not None
         content = AsyncIteratorStream(aiterator=(part async for part in stream))
         return b"HTTP/1.1", 200, b"OK", [], content
 
@@ -46,7 +47,10 @@ async def test_multipart(value, output):
     # bit grungy, but sufficient just for our testing purposes.
     boundary = response.request.headers["Content-Type"].split("boundary=")[-1]
     content_length = response.request.headers["Content-Length"]
-    pdict = {"boundary": boundary.encode("ascii"), "CONTENT-LENGTH": content_length}
+    pdict: dict = {
+        "boundary": boundary.encode("ascii"),
+        "CONTENT-LENGTH": content_length,
+    }
     multipart = cgi.parse_multipart(io.BytesIO(response.content), pdict)
 
     # Note that the expected return type for text fields
@@ -91,7 +95,10 @@ async def test_multipart_file_tuple():
     # bit grungy, but sufficient just for our testing purposes.
     boundary = response.request.headers["Content-Type"].split("boundary=")[-1]
     content_length = response.request.headers["Content-Length"]
-    pdict = {"boundary": boundary.encode("ascii"), "CONTENT-LENGTH": content_length}
+    pdict: dict = {
+        "boundary": boundary.encode("ascii"),
+        "CONTENT-LENGTH": content_length,
+    }
     multipart = cgi.parse_multipart(io.BytesIO(response.content), pdict)
 
     # Note that the expected return type for text fields
@@ -117,6 +124,7 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
         boundary = os.urandom(16).hex()
 
         stream = encode(data=data, files=files)
+        assert isinstance(stream, MultipartStream)
         assert stream.can_replay()
 
         assert stream.content_type == f"multipart/form-data; boundary={boundary}"
@@ -143,6 +151,7 @@ def test_multipart_encode_files_allows_filenames_as_none() -> None:
         boundary = os.urandom(16).hex()
 
         stream = encode(data={}, files=files)
+        assert isinstance(stream, MultipartStream)
         assert stream.can_replay()
 
         assert stream.content_type == f"multipart/form-data; boundary={boundary}"
@@ -169,6 +178,7 @@ def test_multipart_encode_files_guesses_correct_content_type(
         boundary = os.urandom(16).hex()
 
         stream = encode(data={}, files=files)
+        assert isinstance(stream, MultipartStream)
         assert stream.can_replay()
 
         assert stream.content_type == f"multipart/form-data; boundary={boundary}"
@@ -192,6 +202,7 @@ def test_multipart_encode_files_allows_bytes_or_str_content(
         boundary = os.urandom(16).hex()
 
         stream = encode(data={}, files=files)
+        assert isinstance(stream, MultipartStream)
         assert stream.can_replay()
 
         assert stream.content_type == f"multipart/form-data; boundary={boundary}"
@@ -226,7 +237,7 @@ def test_multipart_encode_non_seekable_filelike() -> None:
         yield b"Hello"
         yield b"World"
 
-    fileobj = IteratorIO(data())
+    fileobj: typing.Any = IteratorIO(data())
     files = {"file": fileobj}
     stream = encode(files=files, boundary=b"+++")
     assert not stream.can_replay()
index e62b3e067bb457378907ba1d60eab5bec4b03b7d..c53e95965d57d3c3f08f24a072703265b6184d00 100644 (file)
@@ -7,7 +7,7 @@ def test_status_code_as_int():
 
 
 def test_lowercase_status_code():
-    assert httpx.codes.not_found == 404
+    assert httpx.codes.not_found == 404  # type: ignore
 
 
 def test_reason_phrase_for_status_code():