]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hints to `test_responses.py` (#2488)
authorScirlat Danut <danut.scirlat@gmail.com>
Fri, 9 Feb 2024 09:29:28 +0000 (11:29 +0200)
committerGitHub <noreply@github.com>
Fri, 9 Feb 2024 09:29:28 +0000 (09:29 +0000)
* Add type hints to test_responses.py

* Update tests/test_responses.py

* Linter fix

* Apply suggestions from code review

* Apply suggestions from code review

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
tests/test_responses.py

index 291c46e6def70a944b40e3f67d84b7e4bc3be5b5..57a594901815b9bb283affabe8af320b545cee9b 100644 (file)
@@ -3,6 +3,7 @@ import os
 import time
 from http.cookies import SimpleCookie
 from pathlib import Path
+from typing import AsyncIterator, Callable, Iterator, Union
 
 import anyio
 import pytest
@@ -19,11 +20,13 @@ from starlette.responses import (
     StreamingResponse,
 )
 from starlette.testclient import TestClient
-from starlette.types import Message
+from starlette.types import Message, Receive, Scope, Send
 
+TestClientFactory = Callable[..., TestClient]
 
-def test_text_response(test_client_factory):
-    async def app(scope, receive, send):
+
+def test_text_response(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         response = Response("hello, world", media_type="text/plain")
         await response(scope, receive, send)
 
@@ -32,8 +35,8 @@ def test_text_response(test_client_factory):
     assert response.text == "hello, world"
 
 
-def test_bytes_response(test_client_factory):
-    async def app(scope, receive, send):
+def test_bytes_response(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         response = Response(b"xxxxx", media_type="image/png")
         await response(scope, receive, send)
 
@@ -42,8 +45,8 @@ def test_bytes_response(test_client_factory):
     assert response.content == b"xxxxx"
 
 
-def test_json_none_response(test_client_factory):
-    async def app(scope, receive, send):
+def test_json_none_response(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         response = JSONResponse(None)
         await response(scope, receive, send)
 
@@ -53,8 +56,8 @@ def test_json_none_response(test_client_factory):
     assert response.content == b"null"
 
 
-def test_redirect_response(test_client_factory):
-    async def app(scope, receive, send):
+def test_redirect_response(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         if scope["path"] == "/":
             response = Response("hello, world", media_type="text/plain")
         else:
@@ -67,8 +70,8 @@ def test_redirect_response(test_client_factory):
     assert response.url == "http://testserver/"
 
 
-def test_quoting_redirect_response(test_client_factory):
-    async def app(scope, receive, send):
+def test_quoting_redirect_response(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         if scope["path"] == "/I ♥ Starlette/":
             response = Response("hello, world", media_type="text/plain")
         else:
@@ -81,8 +84,10 @@ def test_quoting_redirect_response(test_client_factory):
     assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/"
 
 
-def test_redirect_response_content_length_header(test_client_factory):
-    async def app(scope, receive, send):
+def test_redirect_response_content_length_header(
+    test_client_factory: TestClientFactory,
+) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         if scope["path"] == "/":
             response = Response("hello", media_type="text/plain")  # pragma: nocover
         else:
@@ -95,18 +100,18 @@ def test_redirect_response_content_length_header(test_client_factory):
     assert response.headers["content-length"] == "0"
 
 
-def test_streaming_response(test_client_factory):
+def test_streaming_response(test_client_factory: TestClientFactory) -> None:
     filled_by_bg_task = ""
 
-    async def app(scope, receive, send):
-        async def numbers(minimum, maximum):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]:
             for i in range(minimum, maximum + 1):
                 yield str(i)
                 if i != maximum:
                     yield ", "
                 await anyio.sleep(0)
 
-        async def numbers_for_cleanup(start=1, stop=5):
+        async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None:
             nonlocal filled_by_bg_task
             async for thing in numbers(start, stop):
                 filled_by_bg_task = filled_by_bg_task + thing
@@ -125,16 +130,18 @@ def test_streaming_response(test_client_factory):
     assert filled_by_bg_task == "6, 7, 8, 9"
 
 
-def test_streaming_response_custom_iterator(test_client_factory):
-    async def app(scope, receive, send):
+def test_streaming_response_custom_iterator(
+    test_client_factory: TestClientFactory,
+) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         class CustomAsyncIterator:
-            def __init__(self):
+            def __init__(self) -> None:
                 self._called = 0
 
-            def __aiter__(self):
+            def __aiter__(self) -> AsyncIterator[str]:
                 return self
 
-            async def __anext__(self):
+            async def __anext__(self) -> str:
                 if self._called == 5:
                     raise StopAsyncIteration()
                 self._called += 1
@@ -148,10 +155,12 @@ def test_streaming_response_custom_iterator(test_client_factory):
     assert response.text == "12345"
 
 
-def test_streaming_response_custom_iterable(test_client_factory):
-    async def app(scope, receive, send):
+def test_streaming_response_custom_iterable(
+    test_client_factory: TestClientFactory,
+) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         class CustomAsyncIterable:
-            async def __aiter__(self):
+            async def __aiter__(self) -> AsyncIterator[Union[str, bytes]]:
                 for i in range(5):
                     yield str(i + 1)
 
@@ -163,9 +172,9 @@ def test_streaming_response_custom_iterable(test_client_factory):
     assert response.text == "12345"
 
 
-def test_sync_streaming_response(test_client_factory):
-    async def app(scope, receive, send):
-        def numbers(minimum, maximum):
+def test_sync_streaming_response(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        def numbers(minimum: int, maximum: int) -> Iterator[str]:
             for i in range(minimum, maximum + 1):
                 yield str(i)
                 if i != maximum:
@@ -180,8 +189,8 @@ def test_sync_streaming_response(test_client_factory):
     assert response.text == "1, 2, 3, 4, 5"
 
 
-def test_response_headers(test_client_factory):
-    async def app(scope, receive, send):
+def test_response_headers(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         headers = {"x-header-1": "123", "x-header-2": "456"}
         response = Response("hello, world", media_type="text/plain", headers=headers)
         response.headers["x-header-2"] = "789"
@@ -193,7 +202,7 @@ def test_response_headers(test_client_factory):
     assert response.headers["x-header-2"] == "789"
 
 
-def test_response_phrase(test_client_factory):
+def test_response_phrase(test_client_factory: TestClientFactory) -> None:
     app = Response(status_code=204)
     client = test_client_factory(app)
     response = client.get("/")
@@ -205,7 +214,7 @@ def test_response_phrase(test_client_factory):
     assert response.reason_phrase == ""
 
 
-def test_file_response(tmpdir, test_client_factory):
+def test_file_response(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "xyz")
     content = b"<file content>" * 1000
     with open(path, "wb") as file:
@@ -213,21 +222,21 @@ def test_file_response(tmpdir, test_client_factory):
 
     filled_by_bg_task = ""
 
-    async def numbers(minimum, maximum):
+    async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]:
         for i in range(minimum, maximum + 1):
             yield str(i)
             if i != maximum:
                 yield ", "
             await anyio.sleep(0)
 
-    async def numbers_for_cleanup(start=1, stop=5):
+    async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None:
         nonlocal filled_by_bg_task
         async for thing in numbers(start, stop):
             filled_by_bg_task = filled_by_bg_task + thing
 
     cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9)
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         response = FileResponse(
             path=path, filename="example.png", background=cleanup_task
         )
@@ -248,7 +257,7 @@ def test_file_response(tmpdir, test_client_factory):
 
 
 @pytest.mark.anyio
-async def test_file_response_on_head_method(tmpdir: Path):
+async def test_file_response_on_head_method(tmpdir: Path) -> None:
     path = os.path.join(tmpdir, "xyz")
     content = b"<file content>" * 1000
     with open(path, "wb") as file:
@@ -277,7 +286,9 @@ async def test_file_response_on_head_method(tmpdir: Path):
     await app({"type": "http", "method": "head"}, receive, send)
 
 
-def test_file_response_with_directory_raises_error(tmpdir, test_client_factory):
+def test_file_response_with_directory_raises_error(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     app = FileResponse(path=tmpdir, filename="example.png")
     client = test_client_factory(app)
     with pytest.raises(RuntimeError) as exc_info:
@@ -285,7 +296,9 @@ def test_file_response_with_directory_raises_error(tmpdir, test_client_factory):
     assert "is not a file" in str(exc_info.value)
 
 
-def test_file_response_with_missing_file_raises_error(tmpdir, test_client_factory):
+def test_file_response_with_missing_file_raises_error(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     path = os.path.join(tmpdir, "404.txt")
     app = FileResponse(path=path, filename="404.txt")
     client = test_client_factory(app)
@@ -294,7 +307,9 @@ def test_file_response_with_missing_file_raises_error(tmpdir, test_client_factor
     assert "does not exist" in str(exc_info.value)
 
 
-def test_file_response_with_chinese_filename(tmpdir, test_client_factory):
+def test_file_response_with_chinese_filename(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     content = b"file content"
     filename = "你好.txt"  # probably "Hello.txt" in Chinese
     path = os.path.join(tmpdir, filename)
@@ -309,7 +324,9 @@ def test_file_response_with_chinese_filename(tmpdir, test_client_factory):
     assert response.headers["content-disposition"] == expected_disposition
 
 
-def test_file_response_with_inline_disposition(tmpdir, test_client_factory):
+def test_file_response_with_inline_disposition(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     content = b"file content"
     filename = "hello.txt"
     path = os.path.join(tmpdir, filename)
@@ -324,13 +341,15 @@ def test_file_response_with_inline_disposition(tmpdir, test_client_factory):
     assert response.headers["content-disposition"] == expected_disposition
 
 
-def test_file_response_with_method_warns(tmpdir, test_client_factory):
+def test_file_response_with_method_warns(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     with pytest.warns(DeprecationWarning):
         FileResponse(path=tmpdir, filename="example.png", method="GET")
 
 
 @pytest.mark.anyio
-async def test_file_response_with_pathsend(tmpdir: Path):
+async def test_file_response_with_pathsend(tmpdir: Path) -> None:
     path = os.path.join(tmpdir, "xyz")
     content = b"<file content>" * 1000
     with open(path, "wb") as file:
@@ -361,12 +380,14 @@ async def test_file_response_with_pathsend(tmpdir: Path):
     )
 
 
-def test_set_cookie(test_client_factory, monkeypatch):
+def test_set_cookie(
+    test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch
+) -> None:
     # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
     mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)
     monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp())
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         response = Response("Hello, world!", media_type="text/plain")
         response.set_cookie(
             "mycookie",
@@ -401,12 +422,16 @@ def test_set_cookie(test_client_factory, monkeypatch):
         pytest.param(10, id="int"),
     ],
 )
-def test_expires_on_set_cookie(test_client_factory, monkeypatch, expires):
+def test_expires_on_set_cookie(
+    test_client_factory: TestClientFactory,
+    monkeypatch: pytest.MonkeyPatch,
+    expires: str,
+) -> None:
     # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
     mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)
     monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp())
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         response = Response("Hello, world!", media_type="text/plain")
         response.set_cookie("mycookie", "myvalue", expires=expires)
         await response(scope, receive, send)
@@ -417,8 +442,8 @@ def test_expires_on_set_cookie(test_client_factory, monkeypatch, expires):
     assert cookie["mycookie"]["expires"] == "Thu, 22 Jan 2037 12:00:10 GMT"
 
 
-def test_delete_cookie(test_client_factory):
-    async def app(scope, receive, send):
+def test_delete_cookie(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         response = Response("Hello, world!", media_type="text/plain")
         if request.cookies.get("mycookie"):
@@ -434,7 +459,7 @@ def test_delete_cookie(test_client_factory):
     assert not response.cookies.get("mycookie")
 
 
-def test_populate_headers(test_client_factory):
+def test_populate_headers(test_client_factory: TestClientFactory) -> None:
     app = Response(content="hi", headers={}, media_type="text/html")
     client = test_client_factory(app)
     response = client.get("/")
@@ -443,14 +468,14 @@ def test_populate_headers(test_client_factory):
     assert response.headers["content-type"] == "text/html; charset=utf-8"
 
 
-def test_head_method(test_client_factory):
+def test_head_method(test_client_factory: TestClientFactory) -> None:
     app = Response("hello, world", media_type="text/plain")
     client = test_client_factory(app)
     response = client.head("/")
     assert response.text == ""
 
 
-def test_empty_response(test_client_factory):
+def test_empty_response(test_client_factory: TestClientFactory) -> None:
     app = Response()
     client: TestClient = test_client_factory(app)
     response = client.get("/")
@@ -459,28 +484,32 @@ def test_empty_response(test_client_factory):
     assert "content-type" not in response.headers
 
 
-def test_empty_204_response(test_client_factory):
+def test_empty_204_response(test_client_factory: TestClientFactory) -> None:
     app = Response(status_code=204)
     client: TestClient = test_client_factory(app)
     response = client.get("/")
     assert "content-length" not in response.headers
 
 
-def test_non_empty_response(test_client_factory):
+def test_non_empty_response(test_client_factory: TestClientFactory) -> None:
     app = Response(content="hi")
     client: TestClient = test_client_factory(app)
     response = client.get("/")
     assert response.headers["content-length"] == "2"
 
 
-def test_response_do_not_add_redundant_charset(test_client_factory):
+def test_response_do_not_add_redundant_charset(
+    test_client_factory: TestClientFactory,
+) -> None:
     app = Response(media_type="text/plain; charset=utf-8")
     client = test_client_factory(app)
     response = client.get("/")
     assert response.headers["content-type"] == "text/plain; charset=utf-8"
 
 
-def test_file_response_known_size(tmpdir, test_client_factory):
+def test_file_response_known_size(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     path = os.path.join(tmpdir, "xyz")
     content = b"<file content>" * 1000
     with open(path, "wb") as file:
@@ -492,14 +521,16 @@ def test_file_response_known_size(tmpdir, test_client_factory):
     assert response.headers["content-length"] == str(len(content))
 
 
-def test_streaming_response_unknown_size(test_client_factory):
+def test_streaming_response_unknown_size(
+    test_client_factory: TestClientFactory,
+) -> None:
     app = StreamingResponse(content=iter(["hello", "world"]))
     client: TestClient = test_client_factory(app)
     response = client.get("/")
     assert "content-length" not in response.headers
 
 
-def test_streaming_response_known_size(test_client_factory):
+def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None:
     app = StreamingResponse(
         content=iter(["hello", "world"]), headers={"content-length": "10"}
     )
@@ -509,16 +540,16 @@ def test_streaming_response_known_size(test_client_factory):
 
 
 @pytest.mark.anyio
-async def test_streaming_response_stops_if_receiving_http_disconnect():
+async def test_streaming_response_stops_if_receiving_http_disconnect() -> None:
     streamed = 0
 
     disconnected = anyio.Event()
 
-    async def receive_disconnect():
+    async def receive_disconnect() -> Message:
         await disconnected.wait()
         return {"type": "http.disconnect"}
 
-    async def send(message):
+    async def send(message: Message) -> None:
         nonlocal streamed
         if message["type"] == "http.response.body":
             streamed += len(message.get("body", b""))
@@ -526,7 +557,7 @@ async def test_streaming_response_stops_if_receiving_http_disconnect():
             if streamed >= 16:
                 disconnected.set()
 
-    async def stream_indefinitely():
+    async def stream_indefinitely() -> AsyncIterator[bytes]:
         while True:
             # Need a sleep for the event loop to switch to another task
             await anyio.sleep(0)