]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hints to `test_requests.py` (#2481)
authorScirlat Danut <danut.scirlat@gmail.com>
Tue, 6 Feb 2024 20:18:43 +0000 (22:18 +0200)
committerGitHub <noreply@github.com>
Tue, 6 Feb 2024 20:18:43 +0000 (13:18 -0700)
* added type annotations to test_requests.py

* requested changes

* indentations

* typos

* typos

* Apply suggestions from code review

* Apply suggestions from code review

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
tests/test_requests.py

index caf110efe37a30c9476ea5ad8b885feaf7638003..b3ce3a04add67609a026d5bf45904c69225e03fb 100644 (file)
@@ -1,5 +1,5 @@
 import sys
-from typing import List, Optional
+from typing import Any, Callable, Dict, Iterator, List, Optional
 
 import anyio
 import pytest
@@ -7,11 +7,14 @@ import pytest
 from starlette.datastructures import Address, State
 from starlette.requests import ClientDisconnect, Request
 from starlette.responses import JSONResponse, PlainTextResponse, Response
-from starlette.types import Message, Scope
+from starlette.testclient import TestClient
+from starlette.types import Message, Receive, Scope, Send
 
+TestClientFactory = Callable[..., TestClient]
 
-def test_request_url(test_client_factory):
-    async def app(scope, receive, send):
+
+def test_request_url(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         data = {"method": request.method, "url": str(request.url)}
         response = JSONResponse(data)
@@ -25,8 +28,8 @@ def test_request_url(test_client_factory):
     assert response.json() == {"method": "GET", "url": "https://example.org:123/"}
 
 
-def test_request_query_params(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_query_params(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         params = dict(request.query_params)
         response = JSONResponse({"params": params})
@@ -41,8 +44,8 @@ def test_request_query_params(test_client_factory):
     any(module in sys.modules for module in ("brotli", "brotlicffi")),
     reason='urllib3 includes "br" to the "accept-encoding" headers.',
 )
-def test_request_headers(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_headers(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         headers = dict(request.headers)
         response = JSONResponse({"headers": headers})
@@ -69,14 +72,14 @@ def test_request_headers(test_client_factory):
         ({}, None),
     ],
 )
-def test_request_client(scope: Scope, expected_client: Optional[Address]):
+def test_request_client(scope: Scope, expected_client: Optional[Address]) -> None:
     scope.update({"type": "http"})  # required by Request's constructor
     client = Request(scope).client
     assert client == expected_client
 
 
-def test_request_body(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_body(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         body = await request.body()
         response = JSONResponse({"body": body.decode()})
@@ -90,12 +93,12 @@ def test_request_body(test_client_factory):
     response = client.post("/", json={"a": "123"})
     assert response.json() == {"body": '{"a": "123"}'}
 
-    response = client.post("/", data="abc")
+    response = client.post("/", data="abc")  # type: ignore
     assert response.json() == {"body": "abc"}
 
 
-def test_request_stream(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_stream(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         body = b""
         async for chunk in request.stream():
@@ -111,12 +114,12 @@ def test_request_stream(test_client_factory):
     response = client.post("/", json={"a": "123"})
     assert response.json() == {"body": '{"a": "123"}'}
 
-    response = client.post("/", data="abc")
+    response = client.post("/", data="abc")  # type: ignore
     assert response.json() == {"body": "abc"}
 
 
-def test_request_form_urlencoded(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_form_urlencoded(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         form = await request.form()
         response = JSONResponse({"form": dict(form)})
@@ -128,8 +131,8 @@ def test_request_form_urlencoded(test_client_factory):
     assert response.json() == {"form": {"abc": "123 @"}}
 
 
-def test_request_form_context_manager(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_form_context_manager(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         async with request.form() as form:
             response = JSONResponse({"form": dict(form)})
@@ -141,8 +144,8 @@ def test_request_form_context_manager(test_client_factory):
     assert response.json() == {"form": {"abc": "123 @"}}
 
 
-def test_request_body_then_stream(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_body_then_stream(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         body = await request.body()
         chunks = b""
@@ -153,12 +156,12 @@ def test_request_body_then_stream(test_client_factory):
 
     client = test_client_factory(app)
 
-    response = client.post("/", data="abc")
+    response = client.post("/", data="abc")  # type: ignore
     assert response.json() == {"body": "abc", "stream": "abc"}
 
 
-def test_request_stream_then_body(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_stream_then_body(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         chunks = b""
         async for chunk in request.stream():
@@ -172,12 +175,12 @@ def test_request_stream_then_body(test_client_factory):
 
     client = test_client_factory(app)
 
-    response = client.post("/", data="abc")
+    response = client.post("/", data="abc")  # type: ignore
     assert response.json() == {"body": "<stream consumed>", "stream": "abc"}
 
 
-def test_request_json(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_json(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         data = await request.json()
         response = JSONResponse({"json": data})
@@ -188,7 +191,7 @@ def test_request_json(test_client_factory):
     assert response.json() == {"json": {"a": "123"}}
 
 
-def test_request_scope_interface():
+def test_request_scope_interface() -> None:
     """
     A Request can be instantiated with a scope, and presents a `Mapping`
     interface.
@@ -199,8 +202,8 @@ def test_request_scope_interface():
     assert len(request) == 3
 
 
-def test_request_raw_path(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_raw_path(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         path = request.scope["path"]
         raw_path = request.scope["raw_path"]
@@ -212,13 +215,15 @@ def test_request_raw_path(test_client_factory):
     assert response.text == "/he/llo, b'/he%2Fllo'"
 
 
-def test_request_without_setting_receive(test_client_factory):
+def test_request_without_setting_receive(
+    test_client_factory: TestClientFactory,
+) -> None:
     """
     If Request is instantiated without the receive channel, then .body()
     is not available.
     """
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope)
         try:
             data = await request.json()
@@ -232,23 +237,26 @@ def test_request_without_setting_receive(test_client_factory):
     assert response.json() == {"json": "Receive channel not available"}
 
 
-def test_request_disconnect(anyio_backend_name, anyio_backend_options):
+def test_request_disconnect(
+    anyio_backend_name: str,
+    anyio_backend_options: Dict[str, Any],
+) -> None:
     """
     If a client disconnect occurs while reading request body
     then ClientDisconnect should be raised.
     """
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         await request.body()
 
-    async def receiver():
+    async def receiver() -> Message:
         return {"type": "http.disconnect"}
 
     scope = {"type": "http", "method": "POST", "path": "/"}
     with pytest.raises(ClientDisconnect):
         anyio.run(
-            app,
+            app,  # type: ignore
             scope,
             receiver,
             None,
@@ -257,14 +265,14 @@ def test_request_disconnect(anyio_backend_name, anyio_backend_options):
         )
 
 
-def test_request_is_disconnected(test_client_factory):
+def test_request_is_disconnected(test_client_factory: TestClientFactory) -> None:
     """
     If a client disconnect occurs while reading request body
     then ClientDisconnect should be raised.
     """
     disconnected_after_response = None
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         nonlocal disconnected_after_response
 
         request = Request(scope, receive)
@@ -280,7 +288,7 @@ def test_request_is_disconnected(test_client_factory):
     assert disconnected_after_response
 
 
-def test_request_state_object():
+def test_request_state_object() -> None:
     scope = {"state": {"old": "foo"}}
 
     s = State(scope["state"])
@@ -294,8 +302,8 @@ def test_request_state_object():
         s.new
 
 
-def test_request_state(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_state(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         request.state.example = 123
         response = JSONResponse({"state.example": request.state.example})
@@ -306,8 +314,8 @@ def test_request_state(test_client_factory):
     assert response.json() == {"state.example": 123}
 
 
-def test_request_cookies(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_cookies(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         mycookie = request.cookies.get("mycookie")
         if mycookie:
@@ -325,7 +333,7 @@ def test_request_cookies(test_client_factory):
     assert response.text == "Hello, cookies!"
 
 
-def test_cookie_lenient_parsing(test_client_factory):
+def test_cookie_lenient_parsing(test_client_factory: TestClientFactory) -> None:
     """
     The following test is based on a cookie set by Okta, a well-known authorization
     service. It turns out that it's common practice to set cookies that would be
@@ -347,7 +355,7 @@ def test_cookie_lenient_parsing(test_client_factory):
         "sessionCookie",
     }
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         response = JSONResponse({"cookies": request.cookies})
         await response(scope, receive, send)
@@ -381,8 +389,12 @@ def test_cookie_lenient_parsing(test_client_factory):
         ("a=b; h=i; a=c", {"a": "c", "h": "i"}),
     ],
 )
-def test_cookies_edge_cases(set_cookie, expected, test_client_factory):
-    async def app(scope, receive, send):
+def test_cookies_edge_cases(
+    set_cookie: str,
+    expected: Dict[str, str],
+    test_client_factory: TestClientFactory,
+) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         response = JSONResponse({"cookies": request.cookies})
         await response(scope, receive, send)
@@ -416,13 +428,17 @@ def test_cookies_edge_cases(set_cookie, expected, test_client_factory):
         # ("  =  b  ;  ;  =  ;   c  =  ;  ", {"": "b", "c": ""}),
     ],
 )
-def test_cookies_invalid(set_cookie, expected, test_client_factory):
+def test_cookies_invalid(
+    set_cookie: str,
+    expected: Dict[str, str],
+    test_client_factory: TestClientFactory,
+) -> None:
     """
     Cookie strings that are against the RFC6265 spec but which browsers will send if set
     via document.cookie.
     """
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         response = JSONResponse({"cookies": request.cookies})
         await response(scope, receive, send)
@@ -433,8 +449,8 @@ def test_cookies_invalid(set_cookie, expected, test_client_factory):
     assert result["cookies"] == expected
 
 
-def test_chunked_encoding(test_client_factory):
-    async def app(scope, receive, send):
+def test_chunked_encoding(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         body = await request.body()
         response = JSONResponse({"body": body.decode()})
@@ -442,16 +458,16 @@ def test_chunked_encoding(test_client_factory):
 
     client = test_client_factory(app)
 
-    def post_body():
+    def post_body() -> Iterator[bytes]:
         yield b"foo"
         yield b"bar"
 
-    response = client.post("/", data=post_body())
+    response = client.post("/", data=post_body())  # type: ignore
     assert response.json() == {"body": "foobar"}
 
 
-def test_request_send_push_promise(test_client_factory):
-    async def app(scope, receive, send):
+def test_request_send_push_promise(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         # the server is push-enabled
         scope["extensions"]["http.response.push"] = {}
 
@@ -466,13 +482,15 @@ def test_request_send_push_promise(test_client_factory):
     assert response.json() == {"json": "OK"}
 
 
-def test_request_send_push_promise_without_push_extension(test_client_factory):
+def test_request_send_push_promise_without_push_extension(
+    test_client_factory: TestClientFactory,
+) -> None:
     """
     If server does not support the `http.response.push` extension,
     .send_push_promise() does nothing.
     """
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope)
         await request.send_push_promise("/style.css")
 
@@ -484,13 +502,15 @@ def test_request_send_push_promise_without_push_extension(test_client_factory):
     assert response.json() == {"json": "OK"}
 
 
-def test_request_send_push_promise_without_setting_send(test_client_factory):
+def test_request_send_push_promise_without_setting_send(
+    test_client_factory: TestClientFactory,
+) -> None:
     """
     If Request is instantiated without the send channel, then
     .send_push_promise() is not available.
     """
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         # the server is push-enabled
         scope["extensions"]["http.response.push"] = {}