From f2e9115c0596519141b201a1025311d80e9a8730 Mon Sep 17 00:00:00 2001
From: jaceksnet
Date: Mon, 24 Jun 2019 19:59:09 +0200
Subject: [PATCH] allow use of characters out of latin-1 in multipart form
 values and filename

Decode multipart form values and upload filenames with the charset
given in the request's Content-Type header instead of always using
latin-1. When no charset is given, the codec name is unknown, or the
bytes are not valid in that codec, decoding falls back to latin-1 so a
malformed request cannot raise out of the parser.
---
 starlette/formparsers.py  | 14 ++++++++--
 tests/test_formparsers.py | 55 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 66 insertions(+), 3 deletions(-)

diff --git a/starlette/formparsers.py b/starlette/formparsers.py
index cbbe8949..8479e0d1 100644
--- a/starlette/formparsers.py
+++ b/starlette/formparsers.py
@@ -31,6 +31,13 @@ class MultiPartMessage(Enum):
     END = 8
 
 
+def _user_safe_decode(src: bytes, codec: str) -> str:
+    try:  # codec is client-supplied, so treat it as untrusted
+        return src.decode(codec)
+    except (UnicodeDecodeError, LookupError):  # bad bytes or unknown codec name
+        return src.decode("latin-1")  # latin-1 accepts any byte sequence
+
+
 class FormParser:
     def __init__(
         self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
@@ -153,6 +160,9 @@ class MultiPartParser:
     async def parse(self) -> FormData:
         # Parse the Content-Type header to get the multipart boundary.
         content_type, params = parse_options_header(self.headers["Content-Type"])
+        charset = params.get(b"charset", "latin-1")
+        if isinstance(charset, bytes):  # default above is already str
+            charset = charset.decode("latin-1")
         boundary = params.get(b"boundary")
 
         # Callbacks dictionary.
@@ -204,7 +214,7 @@ class MultiPartParser:
                     disposition, options = parse_options_header(content_disposition)
                     field_name = options[b"name"].decode("latin-1")
                     if b"filename" in options:
-                        filename = options[b"filename"].decode("latin-1")
+                        filename = _user_safe_decode(options[b"filename"], charset)
                         file = UploadFile(filename=filename, content_type=content_type)
                     else:
                         file = None
@@ -215,7 +225,7 @@ class MultiPartParser:
                         await file.write(message_bytes)
                 elif message_type == MultiPartMessage.PART_END:
                     if file is None:
-                        items.append((field_name, data.decode("latin-1")))
+                        items.append((field_name, _user_safe_decode(data, charset)))
                     else:
                         await file.seek(0)
                         items.append((field_name, file))
diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py
index 8bc71905..c4fedeed 100644
--- a/tests/test_formparsers.py
+++ b/tests/test_formparsers.py
@@ -1,6 +1,6 @@
 import os
 
-from starlette.formparsers import UploadFile
+from starlette.formparsers import UploadFile, _user_safe_decode
 from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.testclient import TestClient
@@ -206,6 +206,49 @@ def test_multipart_request_mixed_files_and_data(tmpdir):
     }
 
 
+def test_multipart_request_with_charset_for_filename(tmpdir):
+    client = TestClient(app)
+    response = client.post(
+        "/",
+        data=(
+            # file part; \xc4\x85 in the filename is the UTF-8 encoding of "ą"
+            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
+            b'Content-Disposition: form-data; name="file"; filename="file\xc4\x85.txt"\r\n'
+            b"Content-Type: text/plain\r\n\r\n"
+            b"\r\n"
+            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
+        ),
+        headers={
+            "Content-Type": "multipart/form-data; charset=utf-8; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
+        },
+    )
+    assert response.json() == {
+        "file": {
+            "filename": "fileą.txt",
+            "content": "",
+            "content_type": "text/plain",
+        }
+    }
+
+
+def test_multipart_request_with_encoded_value(tmpdir):
+    client = TestClient(app)
+    response = client.post(
+        "/",
+        data=(
+            b"--20b303e711c4ab8c443184ac833ab00f\r\n"
+            b"Content-Disposition: form-data; "
+            b'name="value"\r\n\r\n'
+            b"Transf\xc3\xa9rer\r\n"
+            b"--20b303e711c4ab8c443184ac833ab00f--\r\n"
+        ),
+        headers={
+            "Content-Type": "multipart/form-data; charset=utf-8; boundary=20b303e711c4ab8c443184ac833ab00f"
+        },
+    )
+    assert response.json() == {"value": "Transférer"}  # \xc3\xa9 is UTF-8 "é"
+
+
 def test_urlencoded_request_data(tmpdir):
     client = TestClient(app)
     response = client.post("/", data={"some": "data"})
@@ -242,3 +285,13 @@ def test_multipart_multi_field_app_reads_body(tmpdir):
         "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART
     )
     assert response.json() == {"some": "data", "second": "key pair"}
+
+
+def test_user_safe_decode_helper():
+    result = _user_safe_decode(b"\xc4\x99\xc5\xbc\xc4\x87", "utf-8")
+    assert result == "ężć"
+
+
+def test_user_safe_decode_ignores_wrong_charset():
+    result = _user_safe_decode(b"abc", "latin-8")  # expected to be an unknown codec
+    assert result == "abc"
-- 
2.47.2