]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hints to `test_formparsers.py` (#2480)
authorScirlat Danut <danut.scirlat@gmail.com>
Tue, 6 Feb 2024 22:01:44 +0000 (00:01 +0200)
committerGitHub <noreply@github.com>
Tue, 6 Feb 2024 22:01:44 +0000 (22:01 +0000)
* added type annotations to test_formparsers.py

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
tests/test_formparsers.py

index 77ed776eaaa28d6fae7737801b217ff375629a39..4f0cd430d35b7c625ee01b691029fc2a7a7bc1f1 100644 (file)
@@ -1,6 +1,7 @@
 import os
 import typing
 from contextlib import nullcontext as does_not_raise
+from pathlib import Path
 
 import pytest
 
@@ -10,10 +11,14 @@ from starlette.formparsers import MultiPartException, _user_safe_decode
 from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.routing import Mount
+from starlette.testclient import TestClient
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+TestClientFactory = typing.Callable[..., TestClient]
 
 
 class ForceMultipartDict(typing.Dict[typing.Any, typing.Any]):
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return True
 
 
@@ -21,7 +26,7 @@ class ForceMultipartDict(typing.Dict[typing.Any, typing.Any]):
 FORCE_MULTIPART = ForceMultipartDict()
 
 
-async def app(scope, receive, send):
+async def app(scope: Scope, receive: Receive, send: Send) -> None:
     request = Request(scope, receive)
     data = await request.form()
     output: typing.Dict[str, typing.Any] = {}
@@ -41,7 +46,7 @@ async def app(scope, receive, send):
     await response(scope, receive, send)
 
 
-async def multi_items_app(scope, receive, send):
+async def multi_items_app(scope: Scope, receive: Receive, send: Send) -> None:
     request = Request(scope, receive)
     data = await request.form()
     output: typing.Dict[str, typing.List[typing.Any]] = {}
@@ -65,7 +70,7 @@ async def multi_items_app(scope, receive, send):
     await response(scope, receive, send)
 
 
-async def app_with_headers(scope, receive, send):
+async def app_with_headers(scope: Scope, receive: Receive, send: Send) -> None:
     request = Request(scope, receive)
     data = await request.form()
     output: typing.Dict[str, typing.Any] = {}
@@ -86,7 +91,7 @@ async def app_with_headers(scope, receive, send):
     await response(scope, receive, send)
 
 
-async def app_read_body(scope, receive, send):
+async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None:
     request = Request(scope, receive)
     # Read bytes, to force request.stream() to return the already parsed body
     await request.body()
@@ -99,8 +104,8 @@ async def app_read_body(scope, receive, send):
     await response(scope, receive, send)
 
 
-def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000):
-    async def app(scope, receive, send):
+def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000) -> ASGIApp:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         data = await request.form(max_files=max_files, max_fields=max_fields)
         output: typing.Dict[str, typing.Any] = {}
@@ -122,13 +127,17 @@ def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000):
     return app
 
 
-def test_multipart_request_data(tmpdir, test_client_factory):
+def test_multipart_request_data(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     client = test_client_factory(app)
     response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART)
     assert response.json() == {"some": "data"}
 
 
-def test_multipart_request_files(tmpdir, test_client_factory):
+def test_multipart_request_files(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     path = os.path.join(tmpdir, "test.txt")
     with open(path, "wb") as file:
         file.write(b"<file content>")
@@ -146,7 +155,9 @@ def test_multipart_request_files(tmpdir, test_client_factory):
         }
 
 
-def test_multipart_request_files_with_content_type(tmpdir, test_client_factory):
+def test_multipart_request_files_with_content_type(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     path = os.path.join(tmpdir, "test.txt")
     with open(path, "wb") as file:
         file.write(b"<file content>")
@@ -164,7 +175,9 @@ def test_multipart_request_files_with_content_type(tmpdir, test_client_factory):
         }
 
 
-def test_multipart_request_multiple_files(tmpdir, test_client_factory):
+def test_multipart_request_multiple_files(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     path1 = os.path.join(tmpdir, "test1.txt")
     with open(path1, "wb") as file:
         file.write(b"<file1 content>")
@@ -194,7 +207,9 @@ def test_multipart_request_multiple_files(tmpdir, test_client_factory):
         }
 
 
-def test_multipart_request_multiple_files_with_headers(tmpdir, test_client_factory):
+def test_multipart_request_multiple_files_with_headers(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     path1 = os.path.join(tmpdir, "test1.txt")
     with open(path1, "wb") as file:
         file.write(b"<file1 content>")
@@ -231,7 +246,7 @@ def test_multipart_request_multiple_files_with_headers(tmpdir, test_client_facto
         }
 
 
-def test_multi_items(tmpdir, test_client_factory):
+def test_multi_items(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path1 = os.path.join(tmpdir, "test1.txt")
     with open(path1, "wb") as file:
         file.write(b"<file1 content>")
@@ -266,13 +281,15 @@ def test_multi_items(tmpdir, test_client_factory):
         }
 
 
-def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory):
+def test_multipart_request_mixed_files_and_data(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     client = test_client_factory(app)
     response = client.post(
         "/",
         data=(
             # data
-            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
+            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"  # type: ignore
             b'Content-Disposition: form-data; name="field0"\r\n\r\n'
             b"value0\r\n"
             # file
@@ -304,13 +321,15 @@ def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory):
     }
 
 
-def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory):
+def test_multipart_request_with_charset_for_filename(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     client = test_client_factory(app)
     response = client.post(
         "/",
         data=(
             # file
-            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
+            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"  # type: ignore
             b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'  # noqa: E501
             b"Content-Type: text/plain\r\n\r\n"
             b"<file content>\r\n"
@@ -333,13 +352,15 @@ def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory
     }
 
 
-def test_multipart_request_without_charset_for_filename(tmpdir, test_client_factory):
+def test_multipart_request_without_charset_for_filename(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     client = test_client_factory(app)
     response = client.post(
         "/",
         data=(
             # file
-            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
+            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"  # type: ignore
             b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n'  # noqa: E501
             b"Content-Type: image/jpeg\r\n\r\n"
             b"<file content>\r\n"
@@ -361,12 +382,14 @@ def test_multipart_request_without_charset_for_filename(tmpdir, test_client_fact
     }
 
 
-def test_multipart_request_with_encoded_value(tmpdir, test_client_factory):
+def test_multipart_request_with_encoded_value(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     client = test_client_factory(app)
     response = client.post(
         "/",
         data=(
-            b"--20b303e711c4ab8c443184ac833ab00f\r\n"
+            b"--20b303e711c4ab8c443184ac833ab00f\r\n"  # type: ignore
             b"Content-Disposition: form-data; "
             b'name="value"\r\n\r\n'
             b"Transf\xc3\xa9rer\r\n"
@@ -382,37 +405,47 @@ def test_multipart_request_with_encoded_value(tmpdir, test_client_factory):
     assert response.json() == {"value": "Transférer"}
 
 
-def test_urlencoded_request_data(tmpdir, test_client_factory):
+def test_urlencoded_request_data(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     client = test_client_factory(app)
     response = client.post("/", data={"some": "data"})
     assert response.json() == {"some": "data"}
 
 
-def test_no_request_data(tmpdir, test_client_factory):
+def test_no_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
     response = client.post("/")
     assert response.json() == {}
 
 
-def test_urlencoded_percent_encoding(tmpdir, test_client_factory):
+def test_urlencoded_percent_encoding(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     client = test_client_factory(app)
     response = client.post("/", data={"some": "da ta"})
     assert response.json() == {"some": "da ta"}
 
 
-def test_urlencoded_percent_encoding_keys(tmpdir, test_client_factory):
+def test_urlencoded_percent_encoding_keys(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     client = test_client_factory(app)
     response = client.post("/", data={"so me": "data"})
     assert response.json() == {"so me": "data"}
 
 
-def test_urlencoded_multi_field_app_reads_body(tmpdir, test_client_factory):
+def test_urlencoded_multi_field_app_reads_body(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     client = test_client_factory(app_read_body)
     response = client.post("/", data={"some": "data", "second": "key pair"})
     assert response.json() == {"some": "data", "second": "key pair"}
 
 
-def test_multipart_multi_field_app_reads_body(tmpdir, test_client_factory):
+def test_multipart_multi_field_app_reads_body(
+    tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
     client = test_client_factory(app_read_body)
     response = client.post(
         "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART
@@ -420,12 +453,12 @@ def test_multipart_multi_field_app_reads_body(tmpdir, test_client_factory):
     assert response.json() == {"some": "data", "second": "key pair"}
 
 
-def test_user_safe_decode_helper():
+def test_user_safe_decode_helper() -> None:
     result = _user_safe_decode(b"\xc4\x99\xc5\xbc\xc4\x87", "utf-8")
     assert result == "ężć"
 
 
-def test_user_safe_decode_ignores_wrong_charset():
+def test_user_safe_decode_ignores_wrong_charset() -> None:
     result = _user_safe_decode(b"abc", "latin-8")
     assert result == "abc"
 
@@ -437,14 +470,18 @@ def test_user_safe_decode_ignores_wrong_charset():
         (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
     ],
 )
-def test_missing_boundary_parameter(app, expectation, test_client_factory) -> None:
+def test_missing_boundary_parameter(
+    app: ASGIApp,
+    expectation: typing.ContextManager[Exception],
+    test_client_factory: TestClientFactory,
+) -> None:
     client = test_client_factory(app)
     with expectation:
         res = client.post(
             "/",
             data=(
                 # file
-                b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'  # noqa: E501
+                b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'  # type: ignore # noqa: E501
                 b"Content-Type: text/plain\r\n\r\n"
                 b"<file content>\r\n"
             ),
@@ -462,15 +499,17 @@ def test_missing_boundary_parameter(app, expectation, test_client_factory) -> No
     ],
 )
 def test_missing_name_parameter_on_content_disposition(
-    app, expectation, test_client_factory
-):
+    app: ASGIApp,
+    expectation: typing.ContextManager[Exception],
+    test_client_factory: TestClientFactory,
+) -> None:
     client = test_client_factory(app)
     with expectation:
         res = client.post(
             "/",
             data=(
                 # data
-                b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
+                b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"  # type: ignore
                 b'Content-Disposition: form-data; ="field0"\r\n\r\n'
                 b"value0\r\n"
             ),
@@ -493,7 +532,11 @@ def test_missing_name_parameter_on_content_disposition(
         (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
     ],
 )
-def test_too_many_fields_raise(app, expectation, test_client_factory):
+def test_too_many_fields_raise(
+    app: ASGIApp,
+    expectation: typing.ContextManager[Exception],
+    test_client_factory: TestClientFactory,
+) -> None:
     client = test_client_factory(app)
     fields = []
     for i in range(1001):
@@ -504,7 +547,7 @@ def test_too_many_fields_raise(app, expectation, test_client_factory):
     with expectation:
         res = client.post(
             "/",
-            data=data,
+            data=data,  # type: ignore
             headers={"Content-Type": ("multipart/form-data; boundary=B")},
         )
         assert res.status_code == 400
@@ -518,7 +561,11 @@ def test_too_many_fields_raise(app, expectation, test_client_factory):
         (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
     ],
 )
-def test_too_many_files_raise(app, expectation, test_client_factory):
+def test_too_many_files_raise(
+    app: ASGIApp,
+    expectation: typing.ContextManager[Exception],
+    test_client_factory: TestClientFactory,
+) -> None:
     client = test_client_factory(app)
     fields = []
     for i in range(1001):
@@ -531,7 +578,7 @@ def test_too_many_files_raise(app, expectation, test_client_factory):
     with expectation:
         res = client.post(
             "/",
-            data=data,
+            data=data,  # type: ignore
             headers={"Content-Type": ("multipart/form-data; boundary=B")},
         )
         assert res.status_code == 400
@@ -545,7 +592,11 @@ def test_too_many_files_raise(app, expectation, test_client_factory):
         (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
     ],
 )
-def test_too_many_files_single_field_raise(app, expectation, test_client_factory):
+def test_too_many_files_single_field_raise(
+    app: ASGIApp,
+    expectation: typing.ContextManager[Exception],
+    test_client_factory: TestClientFactory,
+) -> None:
     client = test_client_factory(app)
     fields = []
     for i in range(1001):
@@ -560,7 +611,7 @@ def test_too_many_files_single_field_raise(app, expectation, test_client_factory
     with expectation:
         res = client.post(
             "/",
-            data=data,
+            data=data,  # type: ignore
             headers={"Content-Type": ("multipart/form-data; boundary=B")},
         )
         assert res.status_code == 400
@@ -574,7 +625,11 @@ def test_too_many_files_single_field_raise(app, expectation, test_client_factory
         (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
     ],
 )
-def test_too_many_files_and_fields_raise(app, expectation, test_client_factory):
+def test_too_many_files_and_fields_raise(
+    app: ASGIApp,
+    expectation: typing.ContextManager[Exception],
+    test_client_factory: TestClientFactory,
+) -> None:
     client = test_client_factory(app)
     fields = []
     for i in range(1001):
@@ -590,7 +645,7 @@ def test_too_many_files_and_fields_raise(app, expectation, test_client_factory):
     with expectation:
         res = client.post(
             "/",
-            data=data,
+            data=data,  # type: ignore
             headers={"Content-Type": ("multipart/form-data; boundary=B")},
         )
         assert res.status_code == 400
@@ -607,7 +662,11 @@ def test_too_many_files_and_fields_raise(app, expectation, test_client_factory):
         ),
     ],
 )
-def test_max_fields_is_customizable_low_raises(app, expectation, test_client_factory):
+def test_max_fields_is_customizable_low_raises(
+    app: ASGIApp,
+    expectation: typing.ContextManager[Exception],
+    test_client_factory: TestClientFactory,
+) -> None:
     client = test_client_factory(app)
     fields = []
     for i in range(2):
@@ -618,7 +677,7 @@ def test_max_fields_is_customizable_low_raises(app, expectation, test_client_fac
     with expectation:
         res = client.post(
             "/",
-            data=data,
+            data=data,  # type: ignore
             headers={"Content-Type": ("multipart/form-data; boundary=B")},
         )
         assert res.status_code == 400
@@ -635,7 +694,11 @@ def test_max_fields_is_customizable_low_raises(app, expectation, test_client_fac
         ),
     ],
 )
-def test_max_files_is_customizable_low_raises(app, expectation, test_client_factory):
+def test_max_files_is_customizable_low_raises(
+    app: ASGIApp,
+    expectation: typing.ContextManager[Exception],
+    test_client_factory: TestClientFactory,
+) -> None:
     client = test_client_factory(app)
     fields = []
     for i in range(2):
@@ -648,14 +711,16 @@ def test_max_files_is_customizable_low_raises(app, expectation, test_client_fact
     with expectation:
         res = client.post(
             "/",
-            data=data,
+            data=data,  # type: ignore
             headers={"Content-Type": ("multipart/form-data; boundary=B")},
         )
         assert res.status_code == 400
         assert res.text == "Too many files. Maximum number of files is 1."
 
 
-def test_max_fields_is_customizable_high(test_client_factory):
+def test_max_fields_is_customizable_high(
+    test_client_factory: TestClientFactory,
+) -> None:
     client = test_client_factory(make_app_max_parts(max_fields=2000, max_files=2000))
     fields = []
     for i in range(2000):
@@ -671,7 +736,7 @@ def test_max_fields_is_customizable_high(test_client_factory):
     data += b"--B--\r\n"
     res = client.post(
         "/",
-        data=data,
+        data=data,  # type: ignore
         headers={"Content-Type": ("multipart/form-data; boundary=B")},
     )
     assert res.status_code == 200