Use Streams API for both requests and responses. (#648)
author    Tom Christie <tom@tomchristie.com>
          Fri, 20 Dec 2019 16:05:04 +0000 (16:05 +0000)
committer GitHub <noreply@github.com>
          Fri, 20 Dec 2019 16:05:04 +0000 (16:05 +0000)
* Internal ContentStreams API

17 files changed:
httpx/client.py
httpx/content.py [deleted file]
httpx/content_streams.py [new file with mode: 0644]
httpx/dispatch/asgi.py
httpx/dispatch/http11.py
httpx/dispatch/http2.py
httpx/dispatch/proxy_http.py
httpx/models.py
httpx/multipart.py [deleted file]
httpx/utils.py
tests/client/test_redirects.py
tests/dispatch/test_http2.py
tests/models/test_requests.py
tests/models/test_responses.py
tests/test_content_streams.py [moved from tests/test_content.py with 53% similarity]
tests/test_decoders.py
tests/test_multipart.py
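
A quick orientation before the per-file hunks, as a hedged sketch that uses only names visible in the diff below (`encode()`, `ContentStream`, and the new `Request.stream` attribute); the output comments reflect the test expectations further down:

    import httpx

    # A Request built from data/files/json now carries a ContentStream,
    # produced by content_streams.encode(); Request.prepare() merges the
    # stream's implied headers (Content-Type, Content-Length, ...).
    request = httpx.Request("POST", "http://example.org", data={"test": "123"})
    print(type(request.stream).__name__)    # URLEncodedStream
    print(request.headers["Content-Type"])  # application/x-www-form-urlencoded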

diff --git a/httpx/client.py b/httpx/client.py
index 1ceea145d2c72fc8d676b12820bf0aef5c04d767..75343b8936bdbba84c3d5d82b0b86b60f5570f6f 100644
@@ -19,7 +19,7 @@ from .config import (
     UnsetType,
     VerifyTypes,
 )
-from .content import RequestContent
+from .content_streams import ContentStream
 from .dispatch.asgi import ASGIDispatch
 from .dispatch.base import Dispatcher
 from .dispatch.connection_pool import ConnectionPool
@@ -495,11 +495,11 @@ class Client:
         method = self.redirect_method(request, response)
         url = self.redirect_url(request, response)
         headers = self.redirect_headers(request, url, method)
-        content = self.redirect_content(request, method)
+        stream = self.redirect_stream(request, method)
         cookies = Cookies(self.cookies)
-        request = Request(method=method, url=url, headers=headers, cookies=cookies)
-        request.content = content
-        return request
+        return Request(
+            method=method, url=url, headers=headers, cookies=cookies, stream=stream
+        )
 
     def redirect_method(self, request: Request, response: Response) -> str:
         """
@@ -567,15 +567,17 @@ class Client:
 
         return headers
 
-    def redirect_content(self, request: Request, method: str) -> RequestContent:
+    def redirect_stream(
+        self, request: Request, method: str
+    ) -> typing.Optional[ContentStream]:
         """
         Return the body that should be used for the redirect request.
         """
         if method != request.method and method == "GET":
-            return RequestContent()
-        if not request.content.can_replay():
+            return None
+        if not request.stream.can_replay():
             raise RedirectBodyUnavailable()
-        return request.content
+        return request.stream
 
     async def send_handling_auth(
         self,
diff --git a/httpx/content.py b/httpx/content.py
deleted file mode 100644
index 7fe106d..0000000
+++ /dev/null
@@ -1,171 +0,0 @@
-import typing
-from json import dumps as json_dumps
-from urllib.parse import urlencode
-
-from .multipart import multipart_encode
-
-RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
-
-RequestFiles = typing.Dict[
-    str,
-    typing.Union[
-        # file (or str)
-        typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
-        # (filename, file (or str))
-        typing.Tuple[
-            typing.Optional[str], typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
-        ],
-        # (filename, file (or str), content_type)
-        typing.Tuple[
-            typing.Optional[str],
-            typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
-            typing.Optional[str],
-        ],
-    ],
-]
-
-
-class RequestContent:
-    """
-    Base class for request content.
-    Defaults to a "no request body" implementation.
-    """
-
-    def get_headers(self) -> typing.Dict[str, str]:
-        """
-        Return a dictionary of request headers that are implied by the encoding.
-        """
-        return {}
-
-    def can_replay(self) -> bool:
-        """
-        Return `True` if `__aiter__` can be called multiple times.
-
-        We need this in order to determine if we can re-issue a request body
-        when we receive a redirect response.
-        """
-        return True
-
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        yield b""
-
-    async def aread(self) -> bytes:
-        return b"".join([part async for part in self])
-
-
-class BytesRequestContent(RequestContent):
-    """
-    Request content encoded as plain bytes.
-    """
-
-    def __init__(self, body: typing.Union[str, bytes]) -> None:
-        self.body = body.encode("utf-8") if isinstance(body, str) else body
-
-    def get_headers(self) -> typing.Dict[str, str]:
-        content_length = str(len(self.body))
-        return {"Content-Length": content_length}
-
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        yield self.body
-
-
-class StreamingRequestContent(RequestContent):
-    """
-    Request content encoded as plain bytes, using an async byte iterator.
-    """
-
-    def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None:
-        self.aiterator = aiterator
-
-    def can_replay(self) -> bool:
-        return False
-
-    def get_headers(self) -> typing.Dict[str, str]:
-        return {"Transfer-Encoding": "chunked"}
-
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        async for part in self.aiterator:
-            yield part
-
-
-class JSONRequestContent(RequestContent):
-    """
-    Request content encoded as JSON.
-    """
-
-    def __init__(self, json: typing.Any) -> None:
-        self.body = json_dumps(json).encode("utf-8")
-
-    def get_headers(self) -> typing.Dict[str, str]:
-        content_length = str(len(self.body))
-        content_type = "application/json"
-        return {"Content-Length": content_length, "Content-Type": content_type}
-
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        yield self.body
-
-
-class URLEncodedRequestContent(RequestContent):
-    """
-    Request content as URL encoded form data.
-    """
-
-    def __init__(self, data: dict) -> None:
-        self.body = urlencode(data, doseq=True).encode("utf-8")
-
-    def get_headers(self) -> typing.Dict[str, str]:
-        content_length = str(len(self.body))
-        content_type = "application/x-www-form-urlencoded"
-        return {"Content-Length": content_length, "Content-Type": content_type}
-
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        yield self.body
-
-
-class MultipartRequestContent(RequestContent):
-    """
-    Request content as multipart encoded form data.
-    """
-
-    def __init__(self, data: dict, files: dict, boundary: bytes = None) -> None:
-        self.body, self.content_type = multipart_encode(data, files, boundary)
-
-    def get_headers(self) -> typing.Dict[str, str]:
-        content_length = str(len(self.body))
-        content_type = self.content_type
-        return {"Content-Length": content_length, "Content-Type": content_type}
-
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        yield self.body
-
-
-def encode(
-    data: RequestData = None,
-    files: RequestFiles = None,
-    json: typing.Any = None,
-    boundary: bytes = None,
-) -> RequestContent:
-    """
-    Handles encoding the given `data`, `files`, and `json`, returning
-    a `RequestContent` implementation which provides a byte iterator onto
-    the content, as well as `.is_rewindable()` and `.get_headers()` interfaces.
-
-    The `boundary` argument is also included for reproducible test cases
-    when working with multipart data.
-    """
-    if data is None:
-        if json is not None:
-            return JSONRequestContent(json)
-        elif files:
-            return MultipartRequestContent({}, files, boundary=boundary)
-        else:
-            return RequestContent()
-    elif isinstance(data, dict):
-        if files is not None:
-            return MultipartRequestContent(data, files, boundary=boundary)
-        else:
-            return URLEncodedRequestContent(data)
-    elif isinstance(data, (str, bytes)):
-        return BytesRequestContent(data)
-    else:
-        return StreamingRequestContent(data)
diff --git a/httpx/content_streams.py b/httpx/content_streams.py
new file mode 100644
index 0000000..62d150b
--- /dev/null
@@ -0,0 +1,279 @@
+import binascii
+import mimetypes
+import os
+import typing
+from io import BytesIO
+from json import dumps as json_dumps
+from pathlib import Path
+from urllib.parse import urlencode
+
+from .utils import format_form_param
+
+RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
+
+RequestFiles = typing.Dict[
+    str,
+    typing.Union[
+        # file (or str)
+        typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
+        # (filename, file (or str))
+        typing.Tuple[
+            typing.Optional[str], typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
+        ],
+        # (filename, file (or str), content_type)
+        typing.Tuple[
+            typing.Optional[str],
+            typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
+            typing.Optional[str],
+        ],
+    ],
+]
+
+
+class ContentStream:
+    def get_headers(self) -> typing.Dict[str, str]:
+        """
+        Return a dictionary of headers that are implied by the encoding.
+        """
+        return {}
+
+    def can_replay(self) -> bool:
+        """
+        Return `True` if `__aiter__` can be called multiple times.
+
+        We need this in cases such as determining if we can re-issue a request
+        body when we receive a redirect response.
+        """
+        return True
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        yield b""
+
+    async def aclose(self) -> None:
+        pass
+
+
+class ByteStream(ContentStream):
+    """
+    Request content encoded as plain bytes.
+    """
+
+    def __init__(self, body: typing.Union[str, bytes]) -> None:
+        self.body = body.encode("utf-8") if isinstance(body, str) else body
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        if not self.body:
+            return {}
+        content_length = str(len(self.body))
+        return {"Content-Length": content_length}
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        yield self.body
+
+
+class AsyncIteratorStream(ContentStream):
+    """
+    Request content encoded as plain bytes, using an async byte iterator.
+    """
+
+    def __init__(
+        self, aiterator: typing.AsyncIterator[bytes], close_func: typing.Callable = None
+    ) -> None:
+        self.aiterator = aiterator
+        self.close_func = close_func
+
+    def can_replay(self) -> bool:
+        return False
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        return {"Transfer-Encoding": "chunked"}
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        async for part in self.aiterator:
+            yield part
+
+    async def aclose(self) -> None:
+        if self.close_func is not None:
+            await self.close_func()
+
+
+class JSONStream(ContentStream):
+    """
+    Request content encoded as JSON.
+    """
+
+    def __init__(self, json: typing.Any) -> None:
+        self.body = json_dumps(json).encode("utf-8")
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        content_length = str(len(self.body))
+        content_type = "application/json"
+        return {"Content-Length": content_length, "Content-Type": content_type}
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        yield self.body
+
+
+class URLEncodedStream(ContentStream):
+    """
+    Request content as URL encoded form data.
+    """
+
+    def __init__(self, data: dict) -> None:
+        self.body = urlencode(data, doseq=True).encode("utf-8")
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        content_length = str(len(self.body))
+        content_type = "application/x-www-form-urlencoded"
+        return {"Content-Length": content_length, "Content-Type": content_type}
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        yield self.body
+
+
+class MultipartStream(ContentStream):
+    """
+    Request content as multipart encoded form data.
+    """
+
+    class DataField:
+        """
+        A single form field item, within a multipart form field.
+        """
+
+        def __init__(self, name: str, value: typing.Union[str, bytes]) -> None:
+            if not isinstance(name, str):
+                raise TypeError("Invalid type for name. Expected str.")
+            if not isinstance(value, (str, bytes)):
+                raise TypeError("Invalid type for value. Expected str or bytes.")
+            self.name = name
+            self.value = value
+
+        def render_headers(self) -> bytes:
+            name = format_form_param("name", self.name)
+            return b"".join([b"Content-Disposition: form-data; ", name, b"\r\n\r\n"])
+
+        def render_data(self) -> bytes:
+            return (
+                self.value
+                if isinstance(self.value, bytes)
+                else self.value.encode("utf-8")
+            )
+
+    class FileField:
+        """
+        A single file field item, within a multipart form field.
+        """
+
+        def __init__(
+            self, name: str, value: typing.Union[typing.IO[typing.AnyStr], tuple]
+        ) -> None:
+            self.name = name
+            if not isinstance(value, tuple):
+                self.filename = Path(str(getattr(value, "name", "upload"))).name
+                self.file = (
+                    value
+                )  # type: typing.Union[typing.IO[str], typing.IO[bytes]]
+                self.content_type = self.guess_content_type()
+            else:
+                self.filename = value[0]
+                self.file = value[1]
+                self.content_type = (
+                    value[2] if len(value) > 2 else self.guess_content_type()
+                )
+
+        def guess_content_type(self) -> typing.Optional[str]:
+            if self.filename:
+                return (
+                    mimetypes.guess_type(self.filename)[0] or "application/octet-stream"
+                )
+            else:
+                return None
+
+        def render_headers(self) -> bytes:
+            parts = [
+                b"Content-Disposition: form-data; ",
+                format_form_param("name", self.name),
+            ]
+            if self.filename:
+                filename = format_form_param("filename", self.filename)
+                parts.extend([b"; ", filename])
+            if self.content_type is not None:
+                content_type = self.content_type.encode()
+                parts.extend([b"\r\nContent-Type: ", content_type])
+            parts.append(b"\r\n\r\n")
+            return b"".join(parts)
+
+        def render_data(self) -> bytes:
+            if isinstance(self.file, str):
+                content = self.file
+            else:
+                content = self.file.read()
+            return content.encode("utf-8") if isinstance(content, str) else content
+
+    def __init__(self, data: dict, files: dict, boundary: bytes = None) -> None:
+        body = BytesIO()
+        if boundary is None:
+            boundary = binascii.hexlify(os.urandom(16))
+
+        for field in self.iter_fields(data, files):
+            body.write(b"--%s\r\n" % boundary)
+            body.write(field.render_headers())
+            body.write(field.render_data())
+            body.write(b"\r\n")
+
+        body.write(b"--%s--\r\n" % boundary)
+
+        self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
+            "ascii"
+        )
+        self.body = body.getvalue()
+
+    def iter_fields(
+        self, data: dict, files: dict
+    ) -> typing.Iterator[typing.Union["FileField", "DataField"]]:
+        for name, value in data.items():
+            if isinstance(value, (list, dict)):
+                for item in value:
+                    yield self.DataField(name=name, value=item)
+            else:
+                yield self.DataField(name=name, value=value)
+
+        for name, value in files.items():
+            yield self.FileField(name=name, value=value)
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        content_length = str(len(self.body))
+        content_type = self.content_type
+        return {"Content-Length": content_length, "Content-Type": content_type}
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        yield self.body
+
+
+def encode(
+    data: RequestData = None,
+    files: RequestFiles = None,
+    json: typing.Any = None,
+    boundary: bytes = None,
+) -> ContentStream:
+    """
+    Handles encoding the given `data`, `files`, and `json`, returning
+    a `ContentStream` implementation.
+    """
+    if data is None:
+        if json is not None:
+            return JSONStream(json=json)
+        elif files:
+            return MultipartStream(data={}, files=files, boundary=boundary)
+        else:
+            return ByteStream(body=b"")
+    elif isinstance(data, dict):
+        if files is not None:
+            return MultipartStream(data=data, files=files, boundary=boundary)
+        else:
+            return URLEncodedStream(data=data)
+    elif isinstance(data, (str, bytes)):
+        return ByteStream(body=data)
+    else:
+        return AsyncIteratorStream(aiterator=data)
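
As a usage sketch of the module above (mirroring the assertions in `tests/test_content_streams.py` later in this diff), `encode()` selects a stream class from the `data`/`files`/`json` combination, and every stream exposes `get_headers()`, `can_replay()`, and async iteration:

    import asyncio
    from httpx.content_streams import encode

    async def main() -> None:
        # dict data with no files -> URLEncodedStream
        stream = encode(data={"Hello": "world!"})
        assert stream.get_headers()["Content-Type"] == "application/x-www-form-urlencoded"
        assert b"".join([part async for part in stream]) == b"Hello=world%21"

        # json -> JSONStream, which is replayable for redirects
        stream = encode(json={"Hello": "world!"})
        assert stream.can_replay()
        assert stream.get_headers()["Content-Type"] == "application/json"

    asyncio.run(main())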
diff --git a/httpx/dispatch/asgi.py b/httpx/dispatch/asgi.py
index a1e83255951dce4b427ec529e7119474d6ac87af..4e55c21aa9a817ec1e404dfcdb1a39a937d8af63 100644
@@ -1,6 +1,7 @@
 import typing
 
 from ..config import CertTypes, TimeoutTypes, VerifyTypes
+from ..content_streams import ByteStream
 from ..models import Request, Response
 from .base import Dispatcher
 
@@ -77,13 +78,14 @@ class ASGIDispatch(Dispatcher):
         status_code = None
         headers = None
         body_parts = []
-        request_stream = request.stream()
         response_started = False
         response_complete = False
 
+        request_body_chunks = request.stream.__aiter__()
+
         async def receive() -> dict:
             try:
-                body = await request_stream.__anext__()
+                body = await request_body_chunks.__anext__()
             except StopAsyncIteration:
                 return {"type": "http.request", "body": b"", "more_body": False}
             return {"type": "http.request", "body": body, "more_body": True}
@@ -120,10 +122,12 @@ class ASGIDispatch(Dispatcher):
         assert status_code is not None
         assert headers is not None
 
+        stream = ByteStream(b"".join(body_parts))
+
         return Response(
             status_code=status_code,
             http_version="HTTP/1.1",
             headers=headers,
-            content=b"".join(body_parts),
+            stream=stream,
             request=request,
         )
diff --git a/httpx/dispatch/http11.py b/httpx/dispatch/http11.py
index 4367b13f9db5335fab73ad3d7adcb52849400865..384a8106ebbcf2ecbe6c4f72e2789abdf0c6f12d 100644
@@ -4,6 +4,7 @@ import h11
 
 from ..concurrency.base import BaseSocketStream
 from ..config import Timeout
+from ..content_streams import AsyncIteratorStream
 from ..exceptions import ConnectionClosed, ProtocolError
 from ..models import Request, Response
 from ..utils import get_logger
@@ -50,14 +51,16 @@ class HTTP11Connection(OpenConnection):
         await self._send_request(request, timeout)
         await self._send_request_body(request, timeout)
         http_version, status_code, headers = await self._receive_response(timeout)
-        content = self._receive_response_data(timeout)
+        stream = AsyncIteratorStream(
+            aiterator=self._receive_response_data(timeout),
+            close_func=self.response_closed,
+        )
 
         return Response(
             status_code=status_code,
             http_version=http_version,
             headers=headers,
-            content=content,
-            on_close=self.response_closed,
+            stream=stream,
             request=request,
         )
 
@@ -93,7 +96,7 @@ class HTTP11Connection(OpenConnection):
         """
         try:
             # Send the request body.
-            async for chunk in request.stream():
+            async for chunk in request.stream:
                 logger.trace(f"send_data data=Data(<{len(chunk)} bytes>)")
                 event = h11.Data(data=chunk)
                 await self._send_event(event, timeout)
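
The hunk above (and the HTTP/2 and proxy hunks that follow) all use the same pattern: the raw response body iterator is wrapped in an `AsyncIteratorStream` whose `close_func` releases the connection, replacing the removed `on_close` callback. A standalone sketch of that mechanism, with illustrative `body_chunks`/`release_connection` names that are not part of the commit:

    import asyncio
    import typing
    from httpx.content_streams import AsyncIteratorStream

    async def body_chunks() -> typing.AsyncIterator[bytes]:
        yield b"Hello, "
        yield b"world!"

    async def release_connection() -> None:
        print("connection released")

    async def main() -> None:
        stream = AsyncIteratorStream(
            aiterator=body_chunks(), close_func=release_connection
        )
        body = b"".join([chunk async for chunk in stream])
        await stream.aclose()  # awaits close_func, as Response.close() does
        print(body)            # b'Hello, world!'

    asyncio.run(main())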
diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py
index 24ed83dd901298697d36189d46a8c21a0c2c1fb1..6748c59188437985d7098c0f2310d225ab76dd9b 100644
@@ -12,6 +12,7 @@ from ..concurrency.base import (
     lookup_backend,
 )
 from ..config import Timeout
+from ..content_streams import AsyncIteratorStream
 from ..exceptions import ProtocolError
 from ..models import Request, Response
 from ..utils import get_logger
@@ -209,13 +210,15 @@ class HTTP2Stream:
 
         # Receive the response.
         status_code, headers = await self.receive_response(timeout)
-        content = self.body_iter(timeout)
+        stream = AsyncIteratorStream(
+            aiterator=self.body_iter(timeout), close_func=self.close
+        )
+
         return Response(
             status_code=status_code,
             http_version="HTTP/2",
             headers=headers,
-            content=content,
-            on_close=self.close,
+            stream=stream,
             request=request,
         )
 
@@ -238,7 +241,7 @@ class HTTP2Stream:
 
     async def send_body(self, request: Request, timeout: Timeout) -> None:
         logger.trace(f"send_body stream_id={self.stream_id}")
-        async for data in request.stream():
+        async for data in request.stream:
             while data:
                 max_flow = await self.connection.wait_for_outgoing_flow(
                     self.stream_id, timeout
diff --git a/httpx/dispatch/proxy_http.py b/httpx/dispatch/proxy_http.py
index 7da35f5b71382dc1a5509e6245c740e821cd63c2..4497f4faeaa833d372be03d5cca4ecf56f6aeabc 100644
@@ -167,8 +167,9 @@ class HTTPProxy(ConnectionPool):
                 response=proxy_response,
             )
         else:
-            proxy_response.on_close = None
-            await proxy_response.read()
+            # Hack to ingest the response, without closing it.
+            async for chunk in proxy_response._raw_stream:
+                pass
 
         return connection
 
diff --git a/httpx/models.py b/httpx/models.py
index fc938c3ebf67e31bbd7803aff3c6d7273c8c0103..93e6492fa62df0fdce513d9a0a44867db51fed7e 100644
@@ -13,7 +13,7 @@ import chardet
 import rfc3986
 
 from .config import USER_AGENT
-from .content import RequestData, RequestFiles, encode
+from .content_streams import ContentStream, RequestData, RequestFiles, encode
 from .decoders import (
     ACCEPT_ENCODING,
     SUPPORTED_DECODERS,
@@ -71,8 +71,6 @@ ProxiesTypes = typing.Union[
     URLTypes, "Dispatcher", typing.Dict[URLTypes, typing.Union[URLTypes, "Dispatcher"]]
 ]
 
-ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
-
 
 class URL:
     def __init__(
@@ -595,6 +593,7 @@ class Request:
         data: RequestData = None,
         files: RequestFiles = None,
         json: typing.Any = None,
+        stream: ContentStream = None,
     ):
         self.method = method.upper()
         self.url = URL(url, params=params)
@@ -602,11 +601,16 @@ class Request:
         if cookies:
             self._cookies = Cookies(cookies)
             self._cookies.set_cookie_header(self)
-        self.content = encode(data, files, json)
+
+        if stream is not None:
+            self.stream = stream
+        else:
+            self.stream = encode(data, files, json)
+
         self.prepare()
 
     def prepare(self) -> None:
-        for key, value in self.content.get_headers().items():
+        for key, value in self.stream.get_headers().items():
             self.headers.setdefault(key, value)
 
         auto_headers: typing.List[typing.Tuple[bytes, bytes]] = []
@@ -649,11 +653,7 @@ class Request:
         """
         Read and return the request content.
         """
-        return await self.content.aread()
-
-    async def stream(self) -> typing.AsyncIterator[bytes]:
-        async for part in self.content:
-            yield part
+        return b"".join([part async for part in self.stream])
 
 
 class Response:
@@ -663,8 +663,8 @@ class Response:
         *,
         http_version: str = None,
         headers: HeaderTypes = None,
-        content: ResponseContent = None,
-        on_close: typing.Callable = None,
+        stream: ContentStream = None,
+        content: bytes = None,
         request: Request = None,
         history: typing.List["Response"] = None,
         elapsed: datetime.timedelta = None,
@@ -674,20 +674,19 @@ class Response:
         self.headers = Headers(headers)
 
         self.request = request
-        self.on_close = on_close
         self.elapsed = datetime.timedelta(0) if elapsed is None else elapsed
         self.call_next: typing.Optional[typing.Callable] = None
 
         self.history = [] if history is None else list(history)
 
-        if content is None or isinstance(content, bytes):
+        if stream is None:
             self.is_closed = True
             self.is_stream_consumed = True
             self._raw_content = content or b""
         else:
             self.is_closed = False
             self.is_stream_consumed = False
-            self._raw_stream = content
+            self._raw_stream = stream
 
     @property
     def reason_phrase(self) -> str:
@@ -942,8 +941,8 @@ class Response:
         """
         if not self.is_closed:
             self.is_closed = True
-            if self.on_close is not None:
-                await self.on_close()
+            if hasattr(self, "_raw_stream"):
+                await self._raw_stream.aclose()
 
 
 class Cookies(MutableMapping):
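
To summarize the `models.py` changes with a small hedged sketch (names taken from the hunks above): an explicit `stream=` overrides the `encode(data, files, json)` fallback on `Request`, while `Response` treats plain `content=` bytes as already closed and a `stream=` as something to be consumed and closed lazily:

    import httpx
    from httpx.content_streams import ByteStream, JSONStream

    # Explicit stream= bypasses encode(); headers still come from the stream.
    request = httpx.Request(
        "POST", "http://example.org", stream=JSONStream({"hello": "world"})
    )
    assert request.headers["Content-Type"] == "application/json"
    assert request.stream.can_replay()  # body can be re-sent by redirect_stream()

    # content= is closed immediately; stream= is closed via _raw_stream.aclose().
    assert httpx.Response(200, content=b"ok").is_closed
    assert not httpx.Response(200, stream=ByteStream(b"ok")).is_closed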
diff --git a/httpx/multipart.py b/httpx/multipart.py
deleted file mode 100644
index 1f0f4a3..0000000
+++ /dev/null
@@ -1,126 +0,0 @@
-import binascii
-import mimetypes
-import os
-import re
-import typing
-from io import BytesIO
-from pathlib import Path
-
-_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
-_HTML5_FORM_ENCODING_REPLACEMENTS.update(
-    {chr(c): "%{:02X}".format(c) for c in range(0x00, 0x1F + 1) if c != 0x1B}
-)
-_HTML5_FORM_ENCODING_RE = re.compile(
-    r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
-)
-
-
-class Field:
-    def render_headers(self) -> bytes:
-        raise NotImplementedError()  # pragma: nocover
-
-    def render_data(self) -> bytes:
-        raise NotImplementedError()  # pragma: nocover
-
-
-class DataField(Field):
-    def __init__(self, name: str, value: typing.Union[str, bytes]) -> None:
-        if not isinstance(name, str):
-            raise TypeError("Invalid type for name. Expected str.")
-        if not isinstance(value, (str, bytes)):
-            raise TypeError("Invalid type for value. Expected str or bytes.")
-        self.name = name
-        self.value = value
-
-    def render_headers(self) -> bytes:
-        name = _format_param("name", self.name)
-        return b"".join([b"Content-Disposition: form-data; ", name, b"\r\n\r\n"])
-
-    def render_data(self) -> bytes:
-        return (
-            self.value if isinstance(self.value, bytes) else self.value.encode("utf-8")
-        )
-
-
-class FileField(Field):
-    def __init__(
-        self, name: str, value: typing.Union[typing.IO[typing.AnyStr], tuple]
-    ) -> None:
-        self.name = name
-        if not isinstance(value, tuple):
-            self.filename = Path(str(getattr(value, "name", "upload"))).name
-            self.file = value  # type: typing.Union[typing.IO[str], typing.IO[bytes]]
-            self.content_type = self.guess_content_type()
-        else:
-            self.filename = value[0]
-            self.file = value[1]
-            self.content_type = (
-                value[2] if len(value) > 2 else self.guess_content_type()
-            )
-
-    def guess_content_type(self) -> typing.Optional[str]:
-        if self.filename:
-            return mimetypes.guess_type(self.filename)[0] or "application/octet-stream"
-        else:
-            return None
-
-    def render_headers(self) -> bytes:
-        parts = [b"Content-Disposition: form-data; ", _format_param("name", self.name)]
-        if self.filename:
-            filename = _format_param("filename", self.filename)
-            parts.extend([b"; ", filename])
-        if self.content_type is not None:
-            content_type = self.content_type.encode()
-            parts.extend([b"\r\nContent-Type: ", content_type])
-        parts.append(b"\r\n\r\n")
-        return b"".join(parts)
-
-    def render_data(self) -> bytes:
-        if isinstance(self.file, str):
-            content = self.file
-        else:
-            content = self.file.read()
-        return content.encode("utf-8") if isinstance(content, str) else content
-
-
-def iter_fields(data: dict, files: dict) -> typing.Iterator[Field]:
-    for name, value in data.items():
-        if isinstance(value, (list, dict)):
-            for item in value:
-                yield DataField(name=name, value=item)
-        else:
-            yield DataField(name=name, value=value)
-
-    for name, value in files.items():
-        yield FileField(name=name, value=value)
-
-
-def multipart_encode(
-    data: dict, files: dict, boundary: bytes = None
-) -> typing.Tuple[bytes, str]:
-    body = BytesIO()
-    if boundary is None:
-        boundary = binascii.hexlify(os.urandom(16))
-
-    for field in iter_fields(data, files):
-        body.write(b"--%s\r\n" % boundary)
-        body.write(field.render_headers())
-        body.write(field.render_data())
-        body.write(b"\r\n")
-
-    body.write(b"--%s--\r\n" % boundary)
-
-    content_type = "multipart/form-data; boundary=%s" % boundary.decode("ascii")
-
-    return body.getvalue(), content_type
-
-
-def _format_param(name: str, value: typing.Union[str, bytes]) -> bytes:
-    if isinstance(value, bytes):
-        value = value.decode()
-
-    def replacer(match: typing.Match[str]) -> str:
-        return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]
-
-    value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
-    return f'{name}="{value}"'.encode()
diff --git a/httpx/utils.py b/httpx/utils.py
index 9d3aa2230eaddc12d316962d287f97520c51428c..2a6aa3f99ab883ad9313c23b2f84dd61419dfa6b 100644
@@ -17,6 +17,15 @@ if typing.TYPE_CHECKING:  # pragma: no cover
     from .models import URL
 
 
+_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
+_HTML5_FORM_ENCODING_REPLACEMENTS.update(
+    {chr(c): "%{:02X}".format(c) for c in range(0x00, 0x1F + 1) if c != 0x1B}
+)
+_HTML5_FORM_ENCODING_RE = re.compile(
+    r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
+)
+
+
 def normalize_header_key(value: typing.AnyStr, encoding: str = None) -> bytes:
     """
     Coerce str/bytes into a strictly byte-wise HTTP header key.
@@ -61,6 +70,20 @@ def is_known_encoding(encoding: str) -> bool:
     return True
 
 
+def format_form_param(name: str, value: typing.Union[str, bytes]) -> bytes:
+    """
+    Encode a name/value pair within a multipart form.
+    """
+    if isinstance(value, bytes):
+        value = value.decode()
+
+    def replacer(match: typing.Match[str]) -> str:
+        return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]
+
+    value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
+    return f'{name}="{value}"'.encode()
+
+
 # Null bytes; no need to recreate these on each call to guess_json_utf
 _null = "\x00".encode("ascii")  # encoding to ASCII for Python 3
 _null2 = _null * 2
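
`format_form_param` is the former `httpx.multipart._format_param`, moved here and made public so that `content_streams.py` can reuse it; its behaviour is pinned by `TestHeaderParamHTML5Formatting` at the end of this diff. A minimal usage example, copied from those tests:

    from httpx.utils import format_form_param

    assert format_form_param("filename", "n\u00e4me") == b'filename="n\xc3\xa4me"'
    assert format_form_param("filename", "hello\\world\u0022") == b'filename="hello\\\\world%22"'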
diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py
index 7208a1cfec1466c72ef443a2fff9da43d77dce3a..4eb0ba1065d3dd0ce2bc506ec165e86ca570a463 100644
@@ -81,12 +81,12 @@ class MockDispatch(Dispatcher):
             return Response(codes.PERMANENT_REDIRECT, headers=headers, request=request)
 
         elif request.url.path == "/redirect_no_body":
-            await request.read()
+            content = b"".join([part async for part in request.stream])
             headers = {"location": "/redirect_body_target"}
             return Response(codes.SEE_OTHER, headers=headers, request=request)
 
         elif request.url.path == "/redirect_body_target":
-            content = await request.read()
+            content = b"".join([part async for part in request.stream])
             headers = dict(request.headers.items())
             body = json.dumps({"body": content.decode(), "headers": headers}).encode()
             return Response(codes.OK, content=body, request=request)
diff --git a/tests/dispatch/test_http2.py b/tests/dispatch/test_http2.py
index 5aae075d17e704099ec2eaf3ecf563ac5ab96ba6..e1580045e0b67a4d3bb657e0ddd9f755054c8e92 100644
@@ -14,7 +14,7 @@ from .utils import MockHTTP2Backend
 async def app(request):
     method = request.method
     path = request.url.path
-    body = await request.read()
+    body = b"".join([part async for part in request.stream])
     content = json.dumps(
         {"method": method, "path": path, "body": body.decode()}
     ).encode()
diff --git a/tests/models/test_requests.py b/tests/models/test_requests.py
index 6b3b6311d426e741dfb01b89b1df1cfd5e010c18..43afe0436089964a38bf38e21243ad4a180d2ce6 100644
@@ -21,15 +21,19 @@ def test_content_length_header():
 @pytest.mark.asyncio
 async def test_url_encoded_data():
     request = httpx.Request("POST", "http://example.org", data={"test": "123"})
+    content = b"".join([part async for part in request.stream])
+
     assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
-    assert await request.content.aread() == b"test=123"
+    assert content == b"test=123"
 
 
 @pytest.mark.asyncio
 async def test_json_encoded_data():
     request = httpx.Request("POST", "http://example.org", json={"test": 123})
+    content = b"".join([part async for part in request.stream])
+
     assert request.headers["Content-Type"] == "application/json"
-    assert await request.content.aread() == b'{"test": 123}'
+    assert content == b'{"test": 123}'
 
 
 def test_transfer_encoding_header():
diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py
index a5ae254e08f782675262f5ef0dc37665ab86e5fb..7b0281b08507db426a0d4fe01372a1be72012bb2 100644
@@ -5,6 +5,7 @@ from unittest import mock
 import pytest
 
 import httpx
+from httpx.content_streams import AsyncIteratorStream
 
 
 def streaming_body():
@@ -190,7 +191,8 @@ async def test_stream_interface_after_read():
 
 @pytest.mark.asyncio
 async def test_streaming_response():
-    response = httpx.Response(200, content=async_streaming_body())
+    stream = AsyncIteratorStream(aiterator=async_streaming_body())
+    response = httpx.Response(200, stream=stream)
 
     assert response.status_code == 200
     assert not response.is_closed
@@ -204,7 +206,8 @@ async def test_streaming_response():
 
 @pytest.mark.asyncio
 async def test_cannot_read_after_stream_consumed():
-    response = httpx.Response(200, content=async_streaming_body())
+    stream = AsyncIteratorStream(aiterator=async_streaming_body())
+    response = httpx.Response(200, stream=stream)
 
     content = b""
     async for part in response.aiter_bytes():
@@ -216,7 +219,8 @@ async def test_cannot_read_after_stream_consumed():
 
 @pytest.mark.asyncio
 async def test_cannot_read_after_response_closed():
-    response = httpx.Response(200, content=async_streaming_body())
+    stream = AsyncIteratorStream(aiterator=async_streaming_body())
+    response = httpx.Response(200, stream=stream)
 
     await response.close()
 
diff --git a/tests/test_content.py b/tests/test_content_streams.py
similarity index 53%
rename from tests/test_content.py
rename to tests/test_content_streams.py
index 6714bc61697ea7ab6736b2c47980a5a12b42f022..64146ed941f535a8b84f4b3177feafe73bf4e33e 100644
@@ -2,25 +2,27 @@ import io
 
 import pytest
 
-from httpx.content import encode
+from httpx.content_streams import encode
 
 
 @pytest.mark.asyncio
 async def test_empty_content():
-    content = encode()
+    stream = encode()
+    content = b"".join([part async for part in stream])
 
-    assert content.can_replay()
-    assert content.get_headers() == {}
-    assert await content.aread() == b""
+    assert stream.can_replay()
+    assert stream.get_headers() == {}
+    assert content == b""
 
 
 @pytest.mark.asyncio
 async def test_bytes_content():
-    content = encode(data=b"Hello, world!")
+    stream = encode(data=b"Hello, world!")
+    content = b"".join([part async for part in stream])
 
-    assert content.can_replay()
-    assert content.get_headers() == {"Content-Length": "13"}
-    assert await content.aread() == b"Hello, world!"
+    assert stream.can_replay()
+    assert stream.get_headers() == {"Content-Length": "13"}
+    assert content == b"Hello, world!"
 
 
 @pytest.mark.asyncio
@@ -29,48 +31,52 @@ async def test_aiterator_content():
         yield b"Hello, "
         yield b"world!"
 
-    content = encode(data=hello_world())
+    stream = encode(data=hello_world())
+    content = b"".join([part async for part in stream])
 
-    assert not content.can_replay()
-    assert content.get_headers() == {"Transfer-Encoding": "chunked"}
-    assert await content.aread() == b"Hello, world!"
+    assert not stream.can_replay()
+    assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
+    assert content == b"Hello, world!"
 
 
 @pytest.mark.asyncio
 async def test_json_content():
-    content = encode(json={"Hello": "world!"})
+    stream = encode(json={"Hello": "world!"})
+    content = b"".join([part async for part in stream])
 
-    assert content.can_replay()
-    assert content.get_headers() == {
+    assert stream.can_replay()
+    assert stream.get_headers() == {
         "Content-Length": "19",
         "Content-Type": "application/json",
     }
-    assert await content.aread() == b'{"Hello": "world!"}'
+    assert content == b'{"Hello": "world!"}'
 
 
 @pytest.mark.asyncio
 async def test_urlencoded_content():
-    content = encode(data={"Hello": "world!"})
+    stream = encode(data={"Hello": "world!"})
+    content = b"".join([part async for part in stream])
 
-    assert content.can_replay()
-    assert content.get_headers() == {
+    assert stream.can_replay()
+    assert stream.get_headers() == {
         "Content-Length": "14",
         "Content-Type": "application/x-www-form-urlencoded",
     }
-    assert await content.aread() == b"Hello=world%21"
+    assert content == b"Hello=world%21"
 
 
 @pytest.mark.asyncio
 async def test_multipart_files_content():
     files = {"file": io.BytesIO(b"<file content>")}
-    content = encode(files=files, boundary=b"+++")
+    stream = encode(files=files, boundary=b"+++")
+    content = b"".join([part async for part in stream])
 
-    assert content.can_replay()
-    assert content.get_headers() == {
+    assert stream.can_replay()
+    assert stream.get_headers() == {
         "Content-Length": "138",
         "Content-Type": "multipart/form-data; boundary=+++",
     }
-    assert await content.aread() == b"".join(
+    assert content == b"".join(
         [
             b"--+++\r\n",
             b'Content-Disposition: form-data; name="file"; filename="upload"\r\n',
@@ -86,14 +92,15 @@ async def test_multipart_files_content():
 async def test_multipart_data_and_files_content():
     data = {"message": "Hello, world!"}
     files = {"file": io.BytesIO(b"<file content>")}
-    content = encode(data=data, files=files, boundary=b"+++")
+    stream = encode(data=data, files=files, boundary=b"+++")
+    content = b"".join([part async for part in stream])
 
-    assert content.can_replay()
-    assert content.get_headers() == {
+    assert stream.can_replay()
+    assert stream.get_headers() == {
         "Content-Length": "210",
         "Content-Type": "multipart/form-data; boundary=+++",
     }
-    assert await content.aread() == b"".join(
+    assert content == b"".join(
         [
             b"--+++\r\n",
             b'Content-Disposition: form-data; name="message"\r\n',
diff --git a/tests/test_decoders.py b/tests/test_decoders.py
index 76eeca0e320316d343b0412a9910051010db8948..feea42e5e3e727390a823a74cf65b747d68f74de 100644
@@ -4,6 +4,7 @@ import brotli
 import pytest
 
 import httpx
+from httpx.content_streams import AsyncIteratorStream
 from httpx.decoders import (
     BrotliDecoder,
     DeflateDecoder,
@@ -82,7 +83,8 @@ async def test_streaming():
         yield compressor.flush()
 
     headers = [(b"Content-Encoding", b"gzip")]
-    response = httpx.Response(200, headers=headers, content=compress(body))
+    stream = AsyncIteratorStream(aiterator=compress(body))
+    response = httpx.Response(200, headers=headers, stream=stream)
     assert not hasattr(response, "body")
     assert await response.read() == body
 
@@ -137,7 +139,8 @@ async def test_text_decoder(data, encoding):
         for chunk in data:
             yield chunk
 
-    response = httpx.Response(200, content=iterator())
+    stream = AsyncIteratorStream(aiterator=iterator())
+    response = httpx.Response(200, stream=stream)
     await response.read()
     assert response.text == (b"".join(data)).decode(encoding)
 
@@ -149,10 +152,11 @@ async def test_text_decoder_known_encoding():
         yield b"\x83"
         yield b"\x89\x83x\x83\x8b"
 
+    stream = AsyncIteratorStream(aiterator=iterator())
     response = httpx.Response(
         200,
         headers=[(b"Content-Type", b"text/html; charset=shift-jis")],
-        content=iterator(),
+        stream=stream,
     )
 
     await response.read()
diff --git a/tests/test_multipart.py b/tests/test_multipart.py
index 718fb4565b53fd9b2600b28dd89c129b0573b664..a47a7b31d81bcc9b5d0123b75b4d9f930f1c830b 100644
@@ -8,8 +8,9 @@ import pytest
 
 import httpx
 from httpx.config import CertTypes, TimeoutTypes, VerifyTypes
+from httpx.content_streams import encode
 from httpx.dispatch.base import Dispatcher
-from httpx.multipart import _format_param, multipart_encode
+from httpx.utils import format_form_param
 
 
 class MockDispatch(Dispatcher):
@@ -20,7 +21,7 @@ class MockDispatch(Dispatcher):
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
     ) -> httpx.Response:
-        content = await request.read()
+        content = b"".join([part async for part in request.stream])
         return httpx.Response(200, content=content)
 
 
@@ -105,9 +106,9 @@ def test_multipart_encode():
 
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
-        body, content_type = multipart_encode(data=data, files=files)
-        assert content_type == f"multipart/form-data; boundary={boundary}"
-        assert body == (
+        stream = encode(data=data, files=files)
+        assert stream.content_type == f"multipart/form-data; boundary={boundary}"
+        assert stream.body == (
             '--{0}\r\nContent-Disposition: form-data; name="a"\r\n\r\n1\r\n'
             '--{0}\r\nContent-Disposition: form-data; name="b"\r\n\r\nC\r\n'
             '--{0}\r\nContent-Disposition: form-data; name="c"\r\n\r\n11\r\n'
@@ -129,10 +130,10 @@ def test_multipart_encode_files_allows_filenames_as_none():
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
 
-        body, content_type = multipart_encode(data={}, files=files)
+        stream = encode(data={}, files=files)
 
-        assert content_type == f"multipart/form-data; boundary={boundary}"
-        assert body == (
+        assert stream.content_type == f"multipart/form-data; boundary={boundary}"
+        assert stream.body == (
             '--{0}\r\nContent-Disposition: form-data; name="file"\r\n\r\n'
             "<file content>\r\n--{0}--\r\n"
             "".format(boundary).encode("ascii")
@@ -154,10 +155,10 @@ def test_multipart_encode_files_guesses_correct_content_type(
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
 
-        body, content_type = multipart_encode(data={}, files=files)
+        stream = encode(data={}, files=files)
 
-        assert content_type == f"multipart/form-data; boundary={boundary}"
-        assert body == (
+        assert stream.content_type == f"multipart/form-data; boundary={boundary}"
+        assert stream.body == (
             f'--{boundary}\r\nContent-Disposition: form-data; name="file"; '
             f'filename="{file_name}"\r\nContent-Type: '
             f"{expected_content_type}\r\n\r\n<file content>\r\n--{boundary}--\r\n"
@@ -170,10 +171,10 @@ def test_multipart_encode_files_allows_str_content():
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
 
-        body, content_type = multipart_encode(data={}, files=files)
+        stream = encode(data={}, files=files)
 
-        assert content_type == f"multipart/form-data; boundary={boundary}"
-        assert body == (
+        assert stream.content_type == f"multipart/form-data; boundary={boundary}"
+        assert stream.body == (
             '--{0}\r\nContent-Disposition: form-data; name="file"; '
             'filename="test.txt"\r\n'
             "Content-Type: text/plain\r\n\r\n<string content>\r\n"
@@ -184,17 +185,17 @@ def test_multipart_encode_files_allows_str_content():
 
 class TestHeaderParamHTML5Formatting:
     def test_unicode(self):
-        param = _format_param("filename", "n\u00e4me")
+        param = format_form_param("filename", "n\u00e4me")
         assert param == b'filename="n\xc3\xa4me"'
 
     def test_ascii(self):
-        param = _format_param("filename", b"name")
+        param = format_form_param("filename", b"name")
         assert param == b'filename="name"'
 
     def test_unicode_escape(self):
-        param = _format_param("filename", "hello\\world\u0022")
+        param = format_form_param("filename", "hello\\world\u0022")
         assert param == b'filename="hello\\\\world%22"'
 
     def test_unicode_with_control_character(self):
-        param = _format_param("filename", "hello\x1A\x1B\x1C")
+        param = format_form_param("filename", "hello\x1A\x1B\x1C")
         assert param == b'filename="hello%1A\x1B%1C"'