]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Multipart support (#90)
authorTom Christie <tom@tomchristie.com>
Mon, 17 Jun 2019 13:38:24 +0000 (14:38 +0100)
committerGitHub <noreply@github.com>
Mon, 17 Jun 2019 13:38:24 +0000 (14:38 +0100)
* Multipart support

* Test compat with 3.6

http3/api.py
http3/client.py
http3/models.py
http3/multipart.py [new file with mode: 0644]
http3/status_codes.py
tests/client/test_redirects.py
tests/test_multipart.py [new file with mode: 0644]

index 4b77e3bd589b99dc463017c16b79588e367c40dc..99d60128ab609bc4a0c098f0b4c48dc47bf8799c 100644 (file)
@@ -8,6 +8,7 @@ from .models import (
     HeaderTypes,
     QueryParamTypes,
     RequestData,
+    RequestFiles,
     Response,
     URLTypes,
 )
@@ -18,7 +19,8 @@ def request(
     url: URLTypes,
     *,
     params: QueryParamTypes = None,
-    data: RequestData = b"",
+    data: RequestData = None,
+    files: RequestFiles = None,
     json: typing.Any = None,
     headers: HeaderTypes = None,
     cookies: CookieTypes = None,
@@ -36,6 +38,7 @@ def request(
             method=method,
             url=url,
             data=data,
+            files=files,
             json=json,
             params=params,
             headers=headers,
@@ -136,7 +139,8 @@ def head(
 def post(
     url: URLTypes,
     *,
-    data: RequestData = b"",
+    data: RequestData = None,
+    files: RequestFiles = None,
     json: typing.Any = None,
     params: QueryParamTypes = None,
     headers: HeaderTypes = None,
@@ -152,6 +156,7 @@ def post(
         "POST",
         url,
         data=data,
+        files=files,
         json=json,
         params=params,
         headers=headers,
@@ -168,7 +173,8 @@ def post(
 def put(
     url: URLTypes,
     *,
-    data: RequestData = b"",
+    data: RequestData = None,
+    files: RequestFiles = None,
     json: typing.Any = None,
     params: QueryParamTypes = None,
     headers: HeaderTypes = None,
@@ -184,6 +190,7 @@ def put(
         "PUT",
         url,
         data=data,
+        files=files,
         json=json,
         params=params,
         headers=headers,
@@ -200,7 +207,8 @@ def put(
 def patch(
     url: URLTypes,
     *,
-    data: RequestData = b"",
+    data: RequestData = None,
+    files: RequestFiles = None,
     json: typing.Any = None,
     params: QueryParamTypes = None,
     headers: HeaderTypes = None,
@@ -216,6 +224,7 @@ def patch(
         "PATCH",
         url,
         data=data,
+        files=files,
         json=json,
         params=params,
         headers=headers,
@@ -232,7 +241,8 @@ def patch(
 def delete(
     url: URLTypes,
     *,
-    data: RequestData = b"",
+    data: RequestData = None,
+    files: RequestFiles = None,
     json: typing.Any = None,
     params: QueryParamTypes = None,
     headers: HeaderTypes = None,
@@ -248,6 +258,7 @@ def delete(
         "DELETE",
         url,
         data=data,
+        files=files,
         json=json,
         params=params,
         headers=headers,
index 883f25aad2da0f7e89d93abdd24e5ba284160304..13c88270ce11c079ddc331bae18e297da21e2eb5 100644 (file)
@@ -30,6 +30,7 @@ from .models import (
     QueryParamTypes,
     Request,
     RequestData,
+    RequestFiles,
     Response,
     ResponseContent,
     URLTypes,
@@ -161,6 +162,7 @@ class BaseClient:
                     await response.close()
 
         if response.is_redirect:
+
             async def send_next() -> AsyncResponse:
                 nonlocal request, response, verify, cert, allow_redirects, timeout, history
                 request = self.build_redirect_request(request, response)
@@ -342,7 +344,8 @@ class AsyncClient(BaseClient):
         self,
         url: URLTypes,
         *,
-        data: AsyncRequestData = b"",
+        data: AsyncRequestData = None,
+        files: RequestFiles = None,
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -358,6 +361,7 @@ class AsyncClient(BaseClient):
             "POST",
             url,
             data=data,
+            files=files,
             json=json,
             params=params,
             headers=headers,
@@ -374,7 +378,8 @@ class AsyncClient(BaseClient):
         self,
         url: URLTypes,
         *,
-        data: AsyncRequestData = b"",
+        data: AsyncRequestData = None,
+        files: RequestFiles = None,
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -390,6 +395,7 @@ class AsyncClient(BaseClient):
             "PUT",
             url,
             data=data,
+            files=files,
             json=json,
             params=params,
             headers=headers,
@@ -406,7 +412,8 @@ class AsyncClient(BaseClient):
         self,
         url: URLTypes,
         *,
-        data: AsyncRequestData = b"",
+        data: AsyncRequestData = None,
+        files: RequestFiles = None,
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -422,6 +429,7 @@ class AsyncClient(BaseClient):
             "PATCH",
             url,
             data=data,
+            files=files,
             json=json,
             params=params,
             headers=headers,
@@ -438,7 +446,8 @@ class AsyncClient(BaseClient):
         self,
         url: URLTypes,
         *,
-        data: AsyncRequestData = b"",
+        data: AsyncRequestData = None,
+        files: RequestFiles = None,
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -454,6 +463,7 @@ class AsyncClient(BaseClient):
             "DELETE",
             url,
             data=data,
+            files=files,
             json=json,
             params=params,
             headers=headers,
@@ -471,7 +481,8 @@ class AsyncClient(BaseClient):
         method: str,
         url: URLTypes,
         *,
-        data: AsyncRequestData = b"",
+        data: AsyncRequestData = None,
+        files: RequestFiles = None,
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -487,6 +498,7 @@ class AsyncClient(BaseClient):
             method,
             url,
             data=data,
+            files=files,
             json=json,
             params=params,
             headers=headers,
@@ -519,12 +531,14 @@ class AsyncClient(BaseClient):
 
 
 class Client(BaseClient):
-    def _async_request_data(self, data: RequestData) -> AsyncRequestData:
+    def _async_request_data(
+        self, data: RequestData = None
+    ) -> typing.Optional[AsyncRequestData]:
         """
         If the request data is an bytes iterator then return an async bytes
         iterator onto the request data.
         """
-        if isinstance(data, (bytes, dict)):
+        if data is None or isinstance(data, (bytes, dict)):
             return data
 
         # Coerce an iterator into an async iterator, with each item in the
@@ -546,7 +560,8 @@ class Client(BaseClient):
         method: str,
         url: URLTypes,
         *,
-        data: RequestData = b"",
+        data: RequestData = None,
+        files: RequestFiles = None,
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -562,6 +577,7 @@ class Client(BaseClient):
             method,
             url,
             data=self._async_request_data(data),
+            files=files,
             json=json,
             params=params,
             headers=headers,
@@ -696,7 +712,8 @@ class Client(BaseClient):
         self,
         url: URLTypes,
         *,
-        data: RequestData = b"",
+        data: RequestData = None,
+        files: RequestFiles = None,
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -712,6 +729,7 @@ class Client(BaseClient):
             "POST",
             url,
             data=data,
+            files=files,
             json=json,
             params=params,
             headers=headers,
@@ -728,7 +746,8 @@ class Client(BaseClient):
         self,
         url: URLTypes,
         *,
-        data: RequestData = b"",
+        data: RequestData = None,
+        files: RequestFiles = None,
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -744,6 +763,7 @@ class Client(BaseClient):
             "PUT",
             url,
             data=data,
+            files=files,
             json=json,
             params=params,
             headers=headers,
@@ -760,7 +780,8 @@ class Client(BaseClient):
         self,
         url: URLTypes,
         *,
-        data: RequestData = b"",
+        data: RequestData = None,
+        files: RequestFiles = None,
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -776,6 +797,7 @@ class Client(BaseClient):
             "PATCH",
             url,
             data=data,
+            files=files,
             json=json,
             params=params,
             headers=headers,
@@ -792,7 +814,8 @@ class Client(BaseClient):
         self,
         url: URLTypes,
         *,
-        data: RequestData = b"",
+        data: RequestData = None,
+        files: RequestFiles = None,
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -808,6 +831,7 @@ class Client(BaseClient):
             "DELETE",
             url,
             data=data,
+            files=files,
             json=json,
             params=params,
             headers=headers,
index 6b22674b19b2e8d6cca384ac26cb606e5af3abb4..70dad80f9006a4f3e99a9902ea6777241ed6a65d 100644 (file)
@@ -25,6 +25,7 @@ from .exceptions import (
     ResponseNotRead,
     StreamConsumed,
 )
+from .multipart import multipart_encode
 from .status_codes import StatusCode
 from .utils import is_known_encoding, normalize_header_key, normalize_header_value
 
@@ -54,6 +55,17 @@ AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
 
 RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]]
 
+RequestFiles = typing.Dict[
+    str,
+    typing.Union[
+        typing.IO[typing.AnyStr],  # file
+        typing.Tuple[str, typing.IO[typing.AnyStr]],  # (filename, file)
+        typing.Tuple[
+            str, typing.IO[typing.AnyStr], str
+        ],  # (filename, file, content_type)
+    ],
+]
+
 AsyncResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
 
 ResponseContent = typing.Union[bytes, typing.Iterator[bytes]]
@@ -489,11 +501,21 @@ class BaseRequest:
             self._cookies = Cookies(cookies)
             self._cookies.set_cookie_header(self)
 
-    def encode_json(self, json: typing.Any) -> bytes:
-        return jsonlib.dumps(json).encode("utf-8")
-
-    def urlencode_data(self, data: dict) -> bytes:
-        return urlencode(data, doseq=True).encode("utf-8")
+    def encode_data(
+        self, data: dict = None, files: RequestFiles = None, json: typing.Any = None
+    ) -> typing.Tuple[bytes, str]:
+        if json is not None:
+            content = jsonlib.dumps(json).encode("utf-8")
+            content_type = "application/json"
+        elif files is not None:
+            content, content_type = multipart_encode(data or {}, files)
+        elif data is not None:
+            content = urlencode(data, doseq=True).encode("utf-8")
+            content_type = "application/x-www-form-urlencoded"
+        else:
+            content = b""
+            content_type = ""
+        return content, content_type
 
     def prepare(self) -> None:
         content = getattr(self, "content", None)  # type: bytes
@@ -545,24 +567,23 @@ class AsyncRequest(BaseRequest):
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         cookies: CookieTypes = None,
-        data: AsyncRequestData = b"",
+        data: AsyncRequestData = None,
+        files: RequestFiles = None,
         json: typing.Any = None,
     ):
         super().__init__(
             method=method, url=url, params=params, headers=headers, cookies=cookies
         )
 
-        if json is not None:
+        if data is None or isinstance(data, dict):
+            content, content_type = self.encode_data(data, files, json)
             self.is_streaming = False
-            self.content = self.encode_json(json)
-            self.headers["Content-Type"] = "application/json"
+            self.content = content
+            if content_type:
+                self.headers["Content-Type"] = content_type
         elif isinstance(data, bytes):
             self.is_streaming = False
             self.content = data
-        elif isinstance(data, dict):
-            self.is_streaming = False
-            self.content = self.urlencode_data(data)
-            self.headers["Content-Type"] = "application/x-www-form-urlencoded"
         else:
             assert hasattr(data, "__aiter__")
             self.is_streaming = True
@@ -595,24 +616,23 @@ class Request(BaseRequest):
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         cookies: CookieTypes = None,
-        data: RequestData = b"",
+        data: RequestData = None,
+        files: RequestFiles = None,
         json: typing.Any = None,
     ):
         super().__init__(
             method=method, url=url, params=params, headers=headers, cookies=cookies
         )
 
-        if json is not None:
+        if data is None or isinstance(data, dict):
+            content, content_type = self.encode_data(data, files, json)
             self.is_streaming = False
-            self.content = self.encode_json(json)
-            self.headers["Content-Type"] = "application/json"
+            self.content = content
+            if content_type:
+                self.headers["Content-Type"] = content_type
         elif isinstance(data, bytes):
             self.is_streaming = False
             self.content = data
-        elif isinstance(data, dict):
-            self.is_streaming = False
-            self.content = self.urlencode_data(data)
-            self.headers["Content-Type"] = "application/x-www-form-urlencoded"
         else:
             assert hasattr(data, "__iter__")
             self.is_streaming = True
@@ -798,7 +818,7 @@ class AsyncResponse(BaseResponse):
         reason_phrase: str = None,
         protocol: str = None,
         headers: HeaderTypes = None,
-        content: AsyncResponseContent = b"",
+        content: AsyncResponseContent = None,
         on_close: typing.Callable = None,
         request: AsyncRequest = None,
         history: typing.List["BaseResponse"] = None,
@@ -814,10 +834,10 @@ class AsyncResponse(BaseResponse):
 
         self.history = [] if history is None else list(history)
 
-        if isinstance(content, bytes):
+        if content is None or isinstance(content, bytes):
             self.is_closed = True
             self.is_stream_consumed = True
-            self._raw_content = content
+            self._raw_content = content or b""
         else:
             self.is_closed = False
             self.is_stream_consumed = False
@@ -879,7 +899,7 @@ class Response(BaseResponse):
         reason_phrase: str = None,
         protocol: str = None,
         headers: HeaderTypes = None,
-        content: ResponseContent = b"",
+        content: ResponseContent = None,
         on_close: typing.Callable = None,
         request: Request = None,
         history: typing.List["BaseResponse"] = None,
@@ -895,10 +915,10 @@ class Response(BaseResponse):
 
         self.history = [] if history is None else list(history)
 
-        if isinstance(content, bytes):
+        if content is None or isinstance(content, bytes):
             self.is_closed = True
             self.is_stream_consumed = True
-            self._raw_content = content
+            self._raw_content = content or b""
         else:
             self.is_closed = False
             self.is_stream_consumed = False
diff --git a/http3/multipart.py b/http3/multipart.py
new file mode 100644 (file)
index 0000000..07be1f1
--- /dev/null
@@ -0,0 +1,100 @@
+import binascii
+import mimetypes
+import os
+import typing
+from io import BytesIO
+from urllib.parse import quote_plus
+
+
+class Field:
+    def render_headers(self) -> bytes:
+        raise NotImplementedError()  # pragma: nocover
+
+    def render_data(self) -> bytes:
+        raise NotImplementedError()  # pragma: nocover
+
+
+class DataField(Field):
+    def __init__(self, name: str, value: str) -> None:
+        self.name = name
+        self.value = value
+
+    def render_headers(self) -> bytes:
+        name = quote_plus(self.name, encoding="utf-8").encode("ascii")
+        return b"".join(
+            [b'Content-Disposition: form-data; name="', name, b'"\r\n' b"\r\n"]
+        )
+
+    def render_data(self) -> bytes:
+        return quote_plus(self.value, encoding="utf-8").encode("ascii")
+
+
+class FileField(Field):
+    def __init__(
+        self, name: str, value: typing.Union[typing.IO[typing.AnyStr], tuple]
+    ) -> None:
+        self.name = name
+        if not isinstance(value, tuple):
+            self.filename = os.path.basename(getattr(value, "name", "upload"))
+            self.file = value  # type: typing.Union[typing.IO[str], typing.IO[bytes]]
+            self.content_type = self.guess_content_type()
+        else:
+            self.filename = value[0]
+            self.file = value[1]
+            self.content_type = (
+                value[2] if len(value) > 2 else self.guess_content_type()
+            )
+
+    def guess_content_type(self) -> str:
+        return mimetypes.guess_type(self.filename)[0] or "application/octet-stream"
+
+    def render_headers(self) -> bytes:
+        name = quote_plus(self.name, encoding="utf-8").encode("ascii")
+        filename = quote_plus(self.filename, encoding="utf-8").encode("ascii")
+        content_type = self.content_type.encode("ascii")
+        return b"".join(
+            [
+                b'Content-Disposition: form-data; name="',
+                name,
+                b'"; filename="',
+                filename,
+                b'"\r\n',
+                b"Content-Type: ",
+                content_type,
+                b"\r\n",
+                b"\r\n",
+            ]
+        )
+
+    def render_data(self) -> bytes:
+        content = self.file.read()
+        return content.encode("utf-8") if isinstance(content, str) else content
+
+
+def iter_fields(data: dict, files: dict) -> typing.Iterator[Field]:
+    for name, value in data.items():
+        if isinstance(value, list):
+            for item in value:
+                yield DataField(name=name, value=item)
+        else:
+            yield DataField(name=name, value=value)
+
+    for name, value in files.items():
+        yield FileField(name=name, value=value)
+
+
+def multipart_encode(data: dict, files: dict) -> typing.Tuple[bytes, str]:
+    body = BytesIO()
+    boundary = binascii.hexlify(os.urandom(16))
+
+    for field in iter_fields(data, files):
+        body.write(b"--%s\r\n" % boundary)
+        body.write(field.render_headers())
+        body.write(field.render_data())
+        body.write(b"\r\n")
+
+    body.write(b"--%s--\r\n" % boundary)
+
+    content_type = "multipart/form-data; boundary=%s" % boundary.decode("ascii")
+
+    return body.getvalue(), content_type
index eb5e7cebc5d73d711bca79fe582eaa88a9aeeb60..bca2d8ce5ecb36ce5998cd237eeb00cb1911fbd5 100644 (file)
@@ -128,6 +128,6 @@ class StatusCode(IntEnum):
 
 codes = StatusCode
 
-# Include lower-case styles for `requests` compatability.
+#  Include lower-case styles for `requests` compatability.
 for code in codes:
     setattr(codes, code._name_.lower(), int(code))
index 4d4fd48659b77932471e55dd7bd5be501947b536..30967f3369a7aa253641c36fed437735a9d57215 100644 (file)
@@ -168,8 +168,12 @@ async def test_multiple_redirects():
     assert response.status_code == codes.OK
     assert response.url == URL("https://example.org/multiple_redirects")
     assert len(response.history) == 20
-    assert response.history[0].url == URL("https://example.org/multiple_redirects?count=20")
-    assert response.history[1].url == URL("https://example.org/multiple_redirects?count=19")
+    assert response.history[0].url == URL(
+        "https://example.org/multiple_redirects?count=20"
+    )
+    assert response.history[1].url == URL(
+        "https://example.org/multiple_redirects?count=19"
+    )
     assert len(response.history[0].history) == 0
     assert len(response.history[1].history) == 1
 
diff --git a/tests/test_multipart.py b/tests/test_multipart.py
new file mode 100644 (file)
index 0000000..50e1511
--- /dev/null
@@ -0,0 +1,67 @@
+import cgi
+import io
+
+import pytest
+
+from http3 import (
+    CertTypes,
+    Client,
+    Dispatcher,
+    Request,
+    Response,
+    TimeoutTypes,
+    VerifyTypes,
+)
+
+
+class MockDispatch(Dispatcher):
+    def send(
+        self,
+        request: Request,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+    ) -> Response:
+        return Response(200, content=request.read())
+
+
+def test_multipart():
+    client = Client(dispatch=MockDispatch())
+
+    # Test with a single-value 'data' argument, and a plain file 'files' argument.
+    data = {"text": "abc"}
+    files = {"file": io.BytesIO(b"<file content>")}
+    response = client.post("http://127.0.0.1:8000/", data=data, files=files)
+    assert response.status_code == 200
+
+    # We're using the cgi module to verify the behavior here, which is a
+    # bit grungy, but sufficient just for our testing purposes.
+    boundary = response.request.headers["Content-Type"].split("boundary=")[-1]
+    content_length = response.request.headers["Content-Length"]
+    pdict = {"boundary": boundary.encode("ascii"), "CONTENT-LENGTH": content_length}
+    multipart = cgi.parse_multipart(io.BytesIO(response.content), pdict)
+
+    # Note that the expected return type for text fields appears to differs from 3.6 to 3.7+
+    assert multipart["text"] == ["abc"] or multipart["text"] == [b"abc"]
+    assert multipart["file"] == [b"<file content>"]
+
+
+def test_multipart_file_tuple():
+    client = Client(dispatch=MockDispatch())
+
+    # Test with a list of values 'data' argument, and a tuple style 'files' argument.
+    data = {"text": ["abc"]}
+    files = {"file": ("name.txt", io.BytesIO(b"<file content>"))}
+    response = client.post("http://127.0.0.1:8000/", data=data, files=files)
+    assert response.status_code == 200
+
+    # We're using the cgi module to verify the behavior here, which is a
+    # bit grungy, but sufficient just for our testing purposes.
+    boundary = response.request.headers["Content-Type"].split("boundary=")[-1]
+    content_length = response.request.headers["Content-Length"]
+    pdict = {"boundary": boundary.encode("ascii"), "CONTENT-LENGTH": content_length}
+    multipart = cgi.parse_multipart(io.BytesIO(response.content), pdict)
+
+    # Note that the expected return type for text fields appears to differs from 3.6 to 3.7+
+    assert multipart["text"] == ["abc"] or multipart["text"] == [b"abc"]
+    assert multipart["file"] == [b"<file content>"]