From: Tom Christie Date: Mon, 17 Jun 2019 13:38:24 +0000 (+0100) Subject: Multipart support (#90) X-Git-Tag: 0.5.0~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6a44b05e4dd062b1ee318ff472122d095b2565f9;p=thirdparty%2Fhttpx.git Multipart support (#90) * Multipart support * Test compat with 3.6 --- diff --git a/http3/api.py b/http3/api.py index 4b77e3bd..99d60128 100644 --- a/http3/api.py +++ b/http3/api.py @@ -8,6 +8,7 @@ from .models import ( HeaderTypes, QueryParamTypes, RequestData, + RequestFiles, Response, URLTypes, ) @@ -18,7 +19,8 @@ def request( url: URLTypes, *, params: QueryParamTypes = None, - data: RequestData = b"", + data: RequestData = None, + files: RequestFiles = None, json: typing.Any = None, headers: HeaderTypes = None, cookies: CookieTypes = None, @@ -36,6 +38,7 @@ def request( method=method, url=url, data=data, + files=files, json=json, params=params, headers=headers, @@ -136,7 +139,8 @@ def head( def post( url: URLTypes, *, - data: RequestData = b"", + data: RequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -152,6 +156,7 @@ def post( "POST", url, data=data, + files=files, json=json, params=params, headers=headers, @@ -168,7 +173,8 @@ def post( def put( url: URLTypes, *, - data: RequestData = b"", + data: RequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -184,6 +190,7 @@ def put( "PUT", url, data=data, + files=files, json=json, params=params, headers=headers, @@ -200,7 +207,8 @@ def put( def patch( url: URLTypes, *, - data: RequestData = b"", + data: RequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -216,6 +224,7 @@ def patch( "PATCH", url, data=data, + files=files, json=json, params=params, headers=headers, @@ -232,7 +241,8 @@ def patch( def delete( url: URLTypes, *, - data: RequestData = b"", + data: RequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -248,6 +258,7 @@ def delete( "DELETE", url, data=data, + files=files, json=json, params=params, headers=headers, diff --git a/http3/client.py b/http3/client.py index 883f25aa..13c88270 100644 --- a/http3/client.py +++ b/http3/client.py @@ -30,6 +30,7 @@ from .models import ( QueryParamTypes, Request, RequestData, + RequestFiles, Response, ResponseContent, URLTypes, @@ -161,6 +162,7 @@ class BaseClient: await response.close() if response.is_redirect: + async def send_next() -> AsyncResponse: nonlocal request, response, verify, cert, allow_redirects, timeout, history request = self.build_redirect_request(request, response) @@ -342,7 +344,8 @@ class AsyncClient(BaseClient): self, url: URLTypes, *, - data: AsyncRequestData = b"", + data: AsyncRequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -358,6 +361,7 @@ class AsyncClient(BaseClient): "POST", url, data=data, + files=files, json=json, params=params, headers=headers, @@ -374,7 +378,8 @@ class AsyncClient(BaseClient): self, url: URLTypes, *, - data: AsyncRequestData = b"", + data: AsyncRequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -390,6 +395,7 @@ class AsyncClient(BaseClient): "PUT", url, data=data, + files=files, json=json, params=params, headers=headers, @@ -406,7 +412,8 @@ class AsyncClient(BaseClient): self, url: URLTypes, *, - data: AsyncRequestData = b"", + data: AsyncRequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -422,6 +429,7 @@ class AsyncClient(BaseClient): "PATCH", url, data=data, + files=files, json=json, params=params, headers=headers, @@ -438,7 +446,8 @@ class AsyncClient(BaseClient): self, url: URLTypes, *, - data: AsyncRequestData = b"", + data: AsyncRequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -454,6 +463,7 @@ class AsyncClient(BaseClient): "DELETE", url, data=data, + files=files, json=json, params=params, headers=headers, @@ -471,7 +481,8 @@ class AsyncClient(BaseClient): method: str, url: URLTypes, *, - data: AsyncRequestData = b"", + data: AsyncRequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -487,6 +498,7 @@ class AsyncClient(BaseClient): method, url, data=data, + files=files, json=json, params=params, headers=headers, @@ -519,12 +531,14 @@ class AsyncClient(BaseClient): class Client(BaseClient): - def _async_request_data(self, data: RequestData) -> AsyncRequestData: + def _async_request_data( + self, data: RequestData = None + ) -> typing.Optional[AsyncRequestData]: """ If the request data is an bytes iterator then return an async bytes iterator onto the request data. """ - if isinstance(data, (bytes, dict)): + if data is None or isinstance(data, (bytes, dict)): return data # Coerce an iterator into an async iterator, with each item in the @@ -546,7 +560,8 @@ class Client(BaseClient): method: str, url: URLTypes, *, - data: RequestData = b"", + data: RequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -562,6 +577,7 @@ class Client(BaseClient): method, url, data=self._async_request_data(data), + files=files, json=json, params=params, headers=headers, @@ -696,7 +712,8 @@ class Client(BaseClient): self, url: URLTypes, *, - data: RequestData = b"", + data: RequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -712,6 +729,7 @@ class Client(BaseClient): "POST", url, data=data, + files=files, json=json, params=params, headers=headers, @@ -728,7 +746,8 @@ class Client(BaseClient): self, url: URLTypes, *, - data: RequestData = b"", + data: RequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -744,6 +763,7 @@ class Client(BaseClient): "PUT", url, data=data, + files=files, json=json, params=params, headers=headers, @@ -760,7 +780,8 @@ class Client(BaseClient): self, url: URLTypes, *, - data: RequestData = b"", + data: RequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -776,6 +797,7 @@ class Client(BaseClient): "PATCH", url, data=data, + files=files, json=json, params=params, headers=headers, @@ -792,7 +814,8 @@ class Client(BaseClient): self, url: URLTypes, *, - data: RequestData = b"", + data: RequestData = None, + files: RequestFiles = None, json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -808,6 +831,7 @@ class Client(BaseClient): "DELETE", url, data=data, + files=files, json=json, params=params, headers=headers, diff --git a/http3/models.py b/http3/models.py index 6b22674b..70dad80f 100644 --- a/http3/models.py +++ b/http3/models.py @@ -25,6 +25,7 @@ from .exceptions import ( ResponseNotRead, StreamConsumed, ) +from .multipart import multipart_encode from .status_codes import StatusCode from .utils import is_known_encoding, normalize_header_key, normalize_header_value @@ -54,6 +55,17 @@ AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]] RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]] +RequestFiles = typing.Dict[ + str, + typing.Union[ + typing.IO[typing.AnyStr], # file + typing.Tuple[str, typing.IO[typing.AnyStr]], # (filename, file) + typing.Tuple[ + str, typing.IO[typing.AnyStr], str + ], # (filename, file, content_type) + ], +] + AsyncResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]] ResponseContent = typing.Union[bytes, typing.Iterator[bytes]] @@ -489,11 +501,21 @@ class BaseRequest: self._cookies = Cookies(cookies) self._cookies.set_cookie_header(self) - def encode_json(self, json: typing.Any) -> bytes: - return jsonlib.dumps(json).encode("utf-8") - - def urlencode_data(self, data: dict) -> bytes: - return urlencode(data, doseq=True).encode("utf-8") + def encode_data( + self, data: dict = None, files: RequestFiles = None, json: typing.Any = None + ) -> typing.Tuple[bytes, str]: + if json is not None: + content = jsonlib.dumps(json).encode("utf-8") + content_type = "application/json" + elif files is not None: + content, content_type = multipart_encode(data or {}, files) + elif data is not None: + content = urlencode(data, doseq=True).encode("utf-8") + content_type = "application/x-www-form-urlencoded" + else: + content = b"" + content_type = "" + return content, content_type def prepare(self) -> None: content = getattr(self, "content", None) # type: bytes @@ -545,24 +567,23 @@ class AsyncRequest(BaseRequest): params: QueryParamTypes = None, headers: HeaderTypes = None, cookies: CookieTypes = None, - data: AsyncRequestData = b"", + data: AsyncRequestData = None, + files: RequestFiles = None, json: typing.Any = None, ): super().__init__( method=method, url=url, params=params, headers=headers, cookies=cookies ) - if json is not None: + if data is None or isinstance(data, dict): + content, content_type = self.encode_data(data, files, json) self.is_streaming = False - self.content = self.encode_json(json) - self.headers["Content-Type"] = "application/json" + self.content = content + if content_type: + self.headers["Content-Type"] = content_type elif isinstance(data, bytes): self.is_streaming = False self.content = data - elif isinstance(data, dict): - self.is_streaming = False - self.content = self.urlencode_data(data) - self.headers["Content-Type"] = "application/x-www-form-urlencoded" else: assert hasattr(data, "__aiter__") self.is_streaming = True @@ -595,24 +616,23 @@ class Request(BaseRequest): params: QueryParamTypes = None, headers: HeaderTypes = None, cookies: CookieTypes = None, - data: RequestData = b"", + data: RequestData = None, + files: RequestFiles = None, json: typing.Any = None, ): super().__init__( method=method, url=url, params=params, headers=headers, cookies=cookies ) - if json is not None: + if data is None or isinstance(data, dict): + content, content_type = self.encode_data(data, files, json) self.is_streaming = False - self.content = self.encode_json(json) - self.headers["Content-Type"] = "application/json" + self.content = content + if content_type: + self.headers["Content-Type"] = content_type elif isinstance(data, bytes): self.is_streaming = False self.content = data - elif isinstance(data, dict): - self.is_streaming = False - self.content = self.urlencode_data(data) - self.headers["Content-Type"] = "application/x-www-form-urlencoded" else: assert hasattr(data, "__iter__") self.is_streaming = True @@ -798,7 +818,7 @@ class AsyncResponse(BaseResponse): reason_phrase: str = None, protocol: str = None, headers: HeaderTypes = None, - content: AsyncResponseContent = b"", + content: AsyncResponseContent = None, on_close: typing.Callable = None, request: AsyncRequest = None, history: typing.List["BaseResponse"] = None, @@ -814,10 +834,10 @@ class AsyncResponse(BaseResponse): self.history = [] if history is None else list(history) - if isinstance(content, bytes): + if content is None or isinstance(content, bytes): self.is_closed = True self.is_stream_consumed = True - self._raw_content = content + self._raw_content = content or b"" else: self.is_closed = False self.is_stream_consumed = False @@ -879,7 +899,7 @@ class Response(BaseResponse): reason_phrase: str = None, protocol: str = None, headers: HeaderTypes = None, - content: ResponseContent = b"", + content: ResponseContent = None, on_close: typing.Callable = None, request: Request = None, history: typing.List["BaseResponse"] = None, @@ -895,10 +915,10 @@ class Response(BaseResponse): self.history = [] if history is None else list(history) - if isinstance(content, bytes): + if content is None or isinstance(content, bytes): self.is_closed = True self.is_stream_consumed = True - self._raw_content = content + self._raw_content = content or b"" else: self.is_closed = False self.is_stream_consumed = False diff --git a/http3/multipart.py b/http3/multipart.py new file mode 100644 index 00000000..07be1f13 --- /dev/null +++ b/http3/multipart.py @@ -0,0 +1,100 @@ +import binascii +import mimetypes +import os +import typing +from io import BytesIO +from urllib.parse import quote_plus + + +class Field: + def render_headers(self) -> bytes: + raise NotImplementedError() # pragma: nocover + + def render_data(self) -> bytes: + raise NotImplementedError() # pragma: nocover + + +class DataField(Field): + def __init__(self, name: str, value: str) -> None: + self.name = name + self.value = value + + def render_headers(self) -> bytes: + name = quote_plus(self.name, encoding="utf-8").encode("ascii") + return b"".join( + [b'Content-Disposition: form-data; name="', name, b'"\r\n' b"\r\n"] + ) + + def render_data(self) -> bytes: + return quote_plus(self.value, encoding="utf-8").encode("ascii") + + +class FileField(Field): + def __init__( + self, name: str, value: typing.Union[typing.IO[typing.AnyStr], tuple] + ) -> None: + self.name = name + if not isinstance(value, tuple): + self.filename = os.path.basename(getattr(value, "name", "upload")) + self.file = value # type: typing.Union[typing.IO[str], typing.IO[bytes]] + self.content_type = self.guess_content_type() + else: + self.filename = value[0] + self.file = value[1] + self.content_type = ( + value[2] if len(value) > 2 else self.guess_content_type() + ) + + def guess_content_type(self) -> str: + return mimetypes.guess_type(self.filename)[0] or "application/octet-stream" + + def render_headers(self) -> bytes: + name = quote_plus(self.name, encoding="utf-8").encode("ascii") + filename = quote_plus(self.filename, encoding="utf-8").encode("ascii") + content_type = self.content_type.encode("ascii") + return b"".join( + [ + b'Content-Disposition: form-data; name="', + name, + b'"; filename="', + filename, + b'"\r\n', + b"Content-Type: ", + content_type, + b"\r\n", + b"\r\n", + ] + ) + + def render_data(self) -> bytes: + content = self.file.read() + return content.encode("utf-8") if isinstance(content, str) else content + + +def iter_fields(data: dict, files: dict) -> typing.Iterator[Field]: + for name, value in data.items(): + if isinstance(value, list): + for item in value: + yield DataField(name=name, value=item) + else: + yield DataField(name=name, value=value) + + for name, value in files.items(): + yield FileField(name=name, value=value) + + +def multipart_encode(data: dict, files: dict) -> typing.Tuple[bytes, str]: + body = BytesIO() + boundary = binascii.hexlify(os.urandom(16)) + + for field in iter_fields(data, files): + body.write(b"--%s\r\n" % boundary) + body.write(field.render_headers()) + body.write(field.render_data()) + body.write(b"\r\n") + + body.write(b"--%s--\r\n" % boundary) + + content_type = "multipart/form-data; boundary=%s" % boundary.decode("ascii") + + return body.getvalue(), content_type diff --git a/http3/status_codes.py b/http3/status_codes.py index eb5e7ceb..bca2d8ce 100644 --- a/http3/status_codes.py +++ b/http3/status_codes.py @@ -128,6 +128,6 @@ class StatusCode(IntEnum): codes = StatusCode -# Include lower-case styles for `requests` compatability. +#  Include lower-case styles for `requests` compatability. for code in codes: setattr(codes, code._name_.lower(), int(code)) diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index 4d4fd486..30967f33 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -168,8 +168,12 @@ async def test_multiple_redirects(): assert response.status_code == codes.OK assert response.url == URL("https://example.org/multiple_redirects") assert len(response.history) == 20 - assert response.history[0].url == URL("https://example.org/multiple_redirects?count=20") - assert response.history[1].url == URL("https://example.org/multiple_redirects?count=19") + assert response.history[0].url == URL( + "https://example.org/multiple_redirects?count=20" + ) + assert response.history[1].url == URL( + "https://example.org/multiple_redirects?count=19" + ) assert len(response.history[0].history) == 0 assert len(response.history[1].history) == 1 diff --git a/tests/test_multipart.py b/tests/test_multipart.py new file mode 100644 index 00000000..50e1511b --- /dev/null +++ b/tests/test_multipart.py @@ -0,0 +1,67 @@ +import cgi +import io + +import pytest + +from http3 import ( + CertTypes, + Client, + Dispatcher, + Request, + Response, + TimeoutTypes, + VerifyTypes, +) + + +class MockDispatch(Dispatcher): + def send( + self, + request: Request, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, + ) -> Response: + return Response(200, content=request.read()) + + +def test_multipart(): + client = Client(dispatch=MockDispatch()) + + # Test with a single-value 'data' argument, and a plain file 'files' argument. + data = {"text": "abc"} + files = {"file": io.BytesIO(b"")} + response = client.post("http://127.0.0.1:8000/", data=data, files=files) + assert response.status_code == 200 + + # We're using the cgi module to verify the behavior here, which is a + # bit grungy, but sufficient just for our testing purposes. + boundary = response.request.headers["Content-Type"].split("boundary=")[-1] + content_length = response.request.headers["Content-Length"] + pdict = {"boundary": boundary.encode("ascii"), "CONTENT-LENGTH": content_length} + multipart = cgi.parse_multipart(io.BytesIO(response.content), pdict) + + # Note that the expected return type for text fields appears to differs from 3.6 to 3.7+ + assert multipart["text"] == ["abc"] or multipart["text"] == [b"abc"] + assert multipart["file"] == [b""] + + +def test_multipart_file_tuple(): + client = Client(dispatch=MockDispatch()) + + # Test with a list of values 'data' argument, and a tuple style 'files' argument. + data = {"text": ["abc"]} + files = {"file": ("name.txt", io.BytesIO(b""))} + response = client.post("http://127.0.0.1:8000/", data=data, files=files) + assert response.status_code == 200 + + # We're using the cgi module to verify the behavior here, which is a + # bit grungy, but sufficient just for our testing purposes. + boundary = response.request.headers["Content-Type"].split("boundary=")[-1] + content_length = response.request.headers["Content-Length"] + pdict = {"boundary": boundary.encode("ascii"), "CONTENT-LENGTH": content_length} + multipart = cgi.parse_multipart(io.BytesIO(response.content), pdict) + + # Note that the expected return type for text fields appears to differs from 3.6 to 3.7+ + assert multipart["text"] == ["abc"] or multipart["text"] == [b"abc"] + assert multipart["file"] == [b""]