]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Added RequestContent (#636)
authorTom Christie <tom@tomchristie.com>
Wed, 18 Dec 2019 10:51:34 +0000 (10:51 +0000)
committerGitHub <noreply@github.com>
Wed, 18 Dec 2019 10:51:34 +0000 (10:51 +0000)
* Request Content

* Added RequestContent interface

* Docstrings on 'encode(data, files, json)'

* Update httpx/content.py

Co-Authored-By: Florimond Manca <florimond.manca@gmail.com>
* can_rewind -> can_replay

httpx/client.py
httpx/content.py [new file with mode: 0644]
httpx/models.py
httpx/multipart.py
tests/client/test_redirects.py
tests/dispatch/test_http2.py
tests/dispatch/utils.py
tests/models/test_requests.py
tests/test_content.py [new file with mode: 0644]

index 4da7ba15a80ca004f39a76019e7cee4adde2c389..ab86be15c254b055d0e7e23213e460141364b137 100644 (file)
@@ -19,6 +19,7 @@ from .config import (
     UnsetType,
     VerifyTypes,
 )
+from .content import RequestContent
 from .dispatch.asgi import ASGIDispatch
 from .dispatch.base import Dispatcher
 from .dispatch.connection_pool import ConnectionPool
@@ -500,9 +501,9 @@ class Client:
         headers = self.redirect_headers(request, url, method)
         content = self.redirect_content(request, method)
         cookies = Cookies(self.cookies)
-        return Request(
-            method=method, url=url, headers=headers, data=content, cookies=cookies
-        )
+        request = Request(method=method, url=url, headers=headers, cookies=cookies)
+        request.content = content
+        return request
 
     def redirect_method(self, request: Request, response: Response) -> str:
         """
@@ -570,13 +571,13 @@ class Client:
 
         return headers
 
-    def redirect_content(self, request: Request, method: str) -> bytes:
+    def redirect_content(self, request: Request, method: str) -> RequestContent:
         """
         Return the body that should be used for the redirect request.
         """
         if method != request.method and method == "GET":
-            return b""
-        if request.is_streaming:
+            return RequestContent()
+        if not request.content.can_replay():
             raise RedirectBodyUnavailable()
         return request.content
 
diff --git a/httpx/content.py b/httpx/content.py
new file mode 100644 (file)
index 0000000..7fe106d
--- /dev/null
@@ -0,0 +1,171 @@
+import typing
+from json import dumps as json_dumps
+from urllib.parse import urlencode
+
+from .multipart import multipart_encode
+
+RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
+
+RequestFiles = typing.Dict[
+    str,
+    typing.Union[
+        # file (or str)
+        typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
+        # (filename, file (or str))
+        typing.Tuple[
+            typing.Optional[str], typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
+        ],
+        # (filename, file (or str), content_type)
+        typing.Tuple[
+            typing.Optional[str],
+            typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
+            typing.Optional[str],
+        ],
+    ],
+]
+
+
+class RequestContent:
+    """
+    Base class for request content.
+    Defaults to a "no request body" implementation.
+    """
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        """
+        Return a dictionary of request headers that are implied by the encoding.
+        """
+        return {}
+
+    def can_replay(self) -> bool:
+        """
+        Return `True` if `__aiter__` can be called multiple times.
+
+        We need this in order to determine if we can re-issue a request body
+        when we receive a redirect response.
+        """
+        return True
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        yield b""
+
+    async def aread(self) -> bytes:
+        return b"".join([part async for part in self])
+
+
+class BytesRequestContent(RequestContent):
+    """
+    Request content encoded as plain bytes.
+    """
+
+    def __init__(self, body: typing.Union[str, bytes]) -> None:
+        self.body = body.encode("utf-8") if isinstance(body, str) else body
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        content_length = str(len(self.body))
+        return {"Content-Length": content_length}
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        yield self.body
+
+
+class StreamingRequestContent(RequestContent):
+    """
+    Request content encoded as plain bytes, using an async byte iterator.
+    """
+
+    def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None:
+        self.aiterator = aiterator
+
+    def can_replay(self) -> bool:
+        return False
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        return {"Transfer-Encoding": "chunked"}
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        async for part in self.aiterator:
+            yield part
+
+
+class JSONRequestContent(RequestContent):
+    """
+    Request content encoded as JSON.
+    """
+
+    def __init__(self, json: typing.Any) -> None:
+        self.body = json_dumps(json).encode("utf-8")
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        content_length = str(len(self.body))
+        content_type = "application/json"
+        return {"Content-Length": content_length, "Content-Type": content_type}
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        yield self.body
+
+
+class URLEncodedRequestContent(RequestContent):
+    """
+    Request content as URL encoded form data.
+    """
+
+    def __init__(self, data: dict) -> None:
+        self.body = urlencode(data, doseq=True).encode("utf-8")
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        content_length = str(len(self.body))
+        content_type = "application/x-www-form-urlencoded"
+        return {"Content-Length": content_length, "Content-Type": content_type}
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        yield self.body
+
+
+class MultipartRequestContent(RequestContent):
+    """
+    Request content as multipart encoded form data.
+    """
+
+    def __init__(self, data: dict, files: dict, boundary: bytes = None) -> None:
+        self.body, self.content_type = multipart_encode(data, files, boundary)
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        content_length = str(len(self.body))
+        content_type = self.content_type
+        return {"Content-Length": content_length, "Content-Type": content_type}
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        yield self.body
+
+
+def encode(
+    data: RequestData = None,
+    files: RequestFiles = None,
+    json: typing.Any = None,
+    boundary: bytes = None,
+) -> RequestContent:
+    """
+    Handles encoding the given `data`, `files`, and `json`, returning
+    a `RequestContent` implementation which provides a byte iterator onto
+    the content, as well as `.is_rewindable()` and `.get_headers()` interfaces.
+
+    The `boundary` argument is also included for reproducible test cases
+    when working with multipart data.
+    """
+    if data is None:
+        if json is not None:
+            return JSONRequestContent(json)
+        elif files:
+            return MultipartRequestContent({}, files, boundary=boundary)
+        else:
+            return RequestContent()
+    elif isinstance(data, dict):
+        if files is not None:
+            return MultipartRequestContent(data, files, boundary=boundary)
+        else:
+            return URLEncodedRequestContent(data)
+    elif isinstance(data, (str, bytes)):
+        return BytesRequestContent(data)
+    else:
+        return StreamingRequestContent(data)
index 10af37c5323aa38f891bf6870400c6028d0d9f4e..f936da5eaca3c109304877287d6dfddf9605fa8c 100644 (file)
@@ -13,6 +13,7 @@ import chardet
 import rfc3986
 
 from .config import USER_AGENT
+from .content import RequestData, RequestFiles, encode
 from .decoders import (
     ACCEPT_ENCODING,
     SUPPORTED_DECODERS,
@@ -31,7 +32,6 @@ from .exceptions import (
     ResponseNotRead,
     StreamConsumed,
 )
-from .multipart import multipart_encode
 from .status_codes import StatusCode
 from .utils import (
     flatten_queryparams,
@@ -77,26 +77,6 @@ ProxiesTypes = typing.Union[
     URLTypes, "Dispatcher", typing.Dict[URLTypes, typing.Union[URLTypes, "Dispatcher"]]
 ]
 
-RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
-
-RequestFiles = typing.Dict[
-    str,
-    typing.Union[
-        # file (or str)
-        typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
-        # (filename, file (or str))
-        typing.Tuple[
-            typing.Optional[str], typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
-        ],
-        # (filename, file (or str), content_type)
-        typing.Tuple[
-            typing.Optional[str],
-            typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
-            typing.Optional[str],
-        ],
-    ],
-]
-
 ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
 
 
@@ -619,50 +599,18 @@ class Request:
         if cookies:
             self._cookies = Cookies(cookies)
             self._cookies.set_cookie_header(self)
-        if data is None or isinstance(data, dict):
-            content, content_type = self.encode_data(data, files, json)
-            self.is_streaming = False
-            self.content = content
-            if content_type:
-                self.headers.setdefault("Content-Type", content_type)
-        elif isinstance(data, (str, bytes)):
-            data = data.encode("utf-8") if isinstance(data, str) else data
-            self.is_streaming = False
-            self.content = data
-        else:
-            assert hasattr(data, "__aiter__")
-            self.is_streaming = True
-            self.content_aiter = data
+        self.content = encode(data, files, json)
         self.prepare()
 
-    def encode_data(
-        self, data: dict = None, files: RequestFiles = None, json: typing.Any = None
-    ) -> typing.Tuple[bytes, str]:
-        if json is not None:
-            content = jsonlib.dumps(json).encode("utf-8")
-            content_type = "application/json"
-        elif files is not None:
-            content, content_type = multipart_encode(data or {}, files)
-        elif data is not None:
-            content = urlencode(data, doseq=True).encode("utf-8")
-            content_type = "application/x-www-form-urlencoded"
-        else:
-            content = b""
-            content_type = ""
-        return content, content_type
-
     def prepare(self) -> None:
-        content: typing.Optional[bytes] = getattr(self, "content", None)
-        is_streaming = getattr(self, "is_streaming", False)
+        for key, value in self.content.get_headers().items():
+            self.headers.setdefault(key, value)
 
         auto_headers: typing.List[typing.Tuple[bytes, bytes]] = []
 
         has_host = "host" in self.headers
         has_user_agent = "user-agent" in self.headers
         has_accept = "accept" in self.headers
-        has_content_length = (
-            "content-length" in self.headers or "transfer-encoding" in self.headers
-        )
         has_accept_encoding = "accept-encoding" in self.headers
         has_connection = "connection" in self.headers
 
@@ -675,12 +623,6 @@ class Request:
             auto_headers.append((b"user-agent", USER_AGENT.encode("ascii")))
         if not has_accept:
             auto_headers.append((b"accept", b"*/*"))
-        if not has_content_length:
-            if is_streaming:
-                auto_headers.append((b"transfer-encoding", b"chunked"))
-            elif content:
-                content_length = str(len(content)).encode()
-                auto_headers.append((b"content-length", content_length))
         if not has_accept_encoding:
             auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
         if not has_connection:
@@ -704,16 +646,11 @@ class Request:
         """
         Read and return the request content.
         """
-        if not hasattr(self, "content"):
-            self.content = b"".join([part async for part in self.stream()])
-        return self.content
+        return await self.content.aread()
 
     async def stream(self) -> typing.AsyncIterator[bytes]:
-        if self.is_streaming:
-            async for part in self.content_aiter:
-                yield part
-        elif self.content:
-            yield self.content
+        async for part in self.content:
+            yield part
 
 
 class Response:
index a46427942a11138d13409f45590a5fba3d407d97..1f0f4a36de343e0dd660a8c27996ce5709959770 100644 (file)
@@ -95,9 +95,12 @@ def iter_fields(data: dict, files: dict) -> typing.Iterator[Field]:
         yield FileField(name=name, value=value)
 
 
-def multipart_encode(data: dict, files: dict) -> typing.Tuple[bytes, str]:
+def multipart_encode(
+    data: dict, files: dict, boundary: bytes = None
+) -> typing.Tuple[bytes, str]:
     body = BytesIO()
-    boundary = binascii.hexlify(os.urandom(16))
+    if boundary is None:
+        boundary = binascii.hexlify(os.urandom(16))
 
     for field in iter_fields(data, files):
         body.write(b"--%s\r\n" % boundary)
index ec8514195a281e29d613fb29762344b75ea409de..7208a1cfec1466c72ef443a2fff9da43d77dce3a 100644 (file)
@@ -77,7 +77,6 @@ class MockDispatch(Dispatcher):
             return Response(codes.OK, content=content, request=request)
 
         elif request.url.path == "/redirect_body":
-            await request.read()
             headers = {"location": "/redirect_body_target"}
             return Response(codes.PERMANENT_REDIRECT, headers=headers, request=request)
 
index a9c450975613e00c396e8393b7e780129b2a69f2..5aae075d17e704099ec2eaf3ecf563ac5ab96ba6 100644 (file)
@@ -11,13 +11,12 @@ from httpx import Client, Response, TimeoutException
 from .utils import MockHTTP2Backend
 
 
-def app(request):
+async def app(request):
+    method = request.method
+    path = request.url.path
+    body = await request.read()
     content = json.dumps(
-        {
-            "method": request.method,
-            "path": request.url.path,
-            "body": request.content.decode(),
-        }
+        {"method": method, "path": path, "body": body.decode()}
     ).encode()
     headers = {"Content-Length": str(len(content))}
     return Response(200, headers=headers, content=content)
index 3d24ac7400672ced0c4d5a0b82fc8d1cd4007a40..d4cf6521c214a987df43e618d52214aee3de0227 100644 (file)
@@ -53,7 +53,9 @@ class MockHTTP2Server(BaseSocketStream):
         send, self.buffer = self.buffer[:n], self.buffer[n:]
         return send
 
-    def write_no_block(self, data: bytes) -> None:
+    async def write(self, data: bytes, timeout) -> None:
+        if not data:
+            return
         events = self.conn.receive_data(data)
         self.buffer += self.conn.data_to_send()
         for event in events:
@@ -73,13 +75,10 @@ class MockHTTP2Server(BaseSocketStream):
                     )
                     self.buffer += self.conn.data_to_send()
             elif isinstance(event, h2.events.StreamEnded):
-                self.stream_complete(event.stream_id)
+                await self.stream_complete(event.stream_id)
             elif isinstance(event, h2.events.RemoteSettingsChanged):
                 self.settings_changed.append(event)
 
-    async def write(self, data: bytes, timeout) -> None:
-        self.write_no_block(data)
-
     async def close(self) -> None:
         pass
 
@@ -102,7 +101,7 @@ class MockHTTP2Server(BaseSocketStream):
         """
         self.requests[stream_id][-1]["data"] += data
 
-    def stream_complete(self, stream_id):
+    async def stream_complete(self, stream_id):
         """
         Handler for when the HTTP request is completed.
         """
@@ -123,7 +122,7 @@ class MockHTTP2Server(BaseSocketStream):
 
         # Call out to the app.
         request = Request(method, url, headers=headers, data=data)
-        response = self.app(request)
+        response = await self.app(request)
 
         # Write the response to the buffer.
         status_code_bytes = str(response.status_code).encode("ascii")
@@ -192,13 +191,11 @@ class MockRawSocketStream(BaseSocketStream):
     def get_http_version(self) -> str:
         return "HTTP/1.1"
 
-    def write_no_block(self, data: bytes) -> None:
+    async def write(self, data: bytes, timeout) -> None:
+        if not data:
+            return
         self.backend.received_data.append(data)
 
-    async def write(self, data: bytes, timeout: Timeout = None) -> None:
-        if data:
-            self.write_no_block(data)
-
     async def read(self, n, timeout, flag=None) -> bytes:
         await sleep(self.backend.backend, 0)
         if not self.backend.data_to_send:
index 7835f98e72f8f370d8119635a7272b997ae39e10..6b3b6311d426e741dfb01b89b1df1cfd5e010c18 100644 (file)
@@ -18,16 +18,18 @@ def test_content_length_header():
     assert request.headers["Content-Length"] == "8"
 
 
-def test_url_encoded_data():
+@pytest.mark.asyncio
+async def test_url_encoded_data():
     request = httpx.Request("POST", "http://example.org", data={"test": "123"})
     assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
-    assert request.content == b"test=123"
+    assert await request.content.aread() == b"test=123"
 
 
-def test_json_encoded_data():
+@pytest.mark.asyncio
+async def test_json_encoded_data():
     request = httpx.Request("POST", "http://example.org", json={"test": 123})
     assert request.headers["Content-Type"] == "application/json"
-    assert request.content == b'{"test": 123}'
+    assert await request.content.aread() == b'{"test": 123}'
 
 
 def test_transfer_encoding_header():
diff --git a/tests/test_content.py b/tests/test_content.py
new file mode 100644 (file)
index 0000000..6714bc6
--- /dev/null
@@ -0,0 +1,109 @@
+import io
+
+import pytest
+
+from httpx.content import encode
+
+
+@pytest.mark.asyncio
+async def test_empty_content():
+    content = encode()
+
+    assert content.can_replay()
+    assert content.get_headers() == {}
+    assert await content.aread() == b""
+
+
+@pytest.mark.asyncio
+async def test_bytes_content():
+    content = encode(data=b"Hello, world!")
+
+    assert content.can_replay()
+    assert content.get_headers() == {"Content-Length": "13"}
+    assert await content.aread() == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_aiterator_content():
+    async def hello_world():
+        yield b"Hello, "
+        yield b"world!"
+
+    content = encode(data=hello_world())
+
+    assert not content.can_replay()
+    assert content.get_headers() == {"Transfer-Encoding": "chunked"}
+    assert await content.aread() == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_json_content():
+    content = encode(json={"Hello": "world!"})
+
+    assert content.can_replay()
+    assert content.get_headers() == {
+        "Content-Length": "19",
+        "Content-Type": "application/json",
+    }
+    assert await content.aread() == b'{"Hello": "world!"}'
+
+
+@pytest.mark.asyncio
+async def test_urlencoded_content():
+    content = encode(data={"Hello": "world!"})
+
+    assert content.can_replay()
+    assert content.get_headers() == {
+        "Content-Length": "14",
+        "Content-Type": "application/x-www-form-urlencoded",
+    }
+    assert await content.aread() == b"Hello=world%21"
+
+
+@pytest.mark.asyncio
+async def test_multipart_files_content():
+    files = {"file": io.BytesIO(b"<file content>")}
+    content = encode(files=files, boundary=b"+++")
+
+    assert content.can_replay()
+    assert content.get_headers() == {
+        "Content-Length": "138",
+        "Content-Type": "multipart/form-data; boundary=+++",
+    }
+    assert await content.aread() == b"".join(
+        [
+            b"--+++\r\n",
+            b'Content-Disposition: form-data; name="file"; filename="upload"\r\n',
+            b"Content-Type: application/octet-stream\r\n",
+            b"\r\n",
+            b"<file content>\r\n",
+            b"--+++--\r\n",
+        ]
+    )
+
+
+@pytest.mark.asyncio
+async def test_multipart_data_and_files_content():
+    data = {"message": "Hello, world!"}
+    files = {"file": io.BytesIO(b"<file content>")}
+    content = encode(data=data, files=files, boundary=b"+++")
+
+    assert content.can_replay()
+    assert content.get_headers() == {
+        "Content-Length": "210",
+        "Content-Type": "multipart/form-data; boundary=+++",
+    }
+    assert await content.aread() == b"".join(
+        [
+            b"--+++\r\n",
+            b'Content-Disposition: form-data; name="message"\r\n',
+            b"\r\n",
+            b"Hello, world!\r\n",
+            b"--+++\r\n",
+            b'Content-Disposition: form-data; name="file"; filename="upload"\r\n',
+            b"Content-Type: application/octet-stream\r\n",
+            b"\r\n",
+            b"<file content>\r\n",
+            b"--+++--\r\n",
+        ]
+    )