]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Support Response(text=...), Response(html=...), Response(json=...) (#1297)
authorTom Christie <tom@tomchristie.com>
Mon, 21 Sep 2020 10:19:19 +0000 (11:19 +0100)
committerGitHub <noreply@github.com>
Mon, 21 Sep 2020 10:19:19 +0000 (11:19 +0100)
* Refactor content_streams internally

* Tidy up multipart

* Use ByteStream annotation internally

* Support Response(text=...), Response(html=...), Response(json=...)

* Add tests for Response(text=..., html=..., json=...)

httpx/_content.py
httpx/_models.py
tests/client/test_auth.py
tests/client/test_cookies.py
tests/client/test_headers.py
tests/client/test_queryparams.py
tests/client/test_redirects.py
tests/models/test_responses.py

index 8e5d5e95512e118ff4b2b896c811064564294867..bf402c9e299d7df07305efb55f3c91b03056e632 100644 (file)
@@ -1,6 +1,15 @@
 import inspect
-import typing
 from json import dumps as json_dumps
+from typing import (
+    Any,
+    AsyncIterable,
+    AsyncIterator,
+    Dict,
+    Iterable,
+    Iterator,
+    Tuple,
+    Union,
+)
 from urllib.parse import urlencode
 
 from ._exceptions import StreamConsumed
@@ -22,10 +31,10 @@ class PlainByteStream:
     def __init__(self, body: bytes) -> None:
         self._body = body
 
-    def __iter__(self) -> typing.Iterator[bytes]:
+    def __iter__(self) -> Iterator[bytes]:
         yield self._body
 
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+    async def __aiter__(self) -> AsyncIterator[bytes]:
         yield self._body
 
 
@@ -34,11 +43,11 @@ class GeneratorStream:
     Request content encoded as plain bytes, using an byte generator.
     """
 
-    def __init__(self, generator: typing.Iterable[bytes]) -> None:
+    def __init__(self, generator: Iterable[bytes]) -> None:
         self._generator = generator
         self._is_stream_consumed = False
 
-    def __iter__(self) -> typing.Iterator[bytes]:
+    def __iter__(self) -> Iterator[bytes]:
         if self._is_stream_consumed:
             raise StreamConsumed()
 
@@ -52,11 +61,11 @@ class AsyncGeneratorStream:
     Request content encoded as plain bytes, using an async byte iterator.
     """
 
-    def __init__(self, agenerator: typing.AsyncIterable[bytes]) -> None:
+    def __init__(self, agenerator: AsyncIterable[bytes]) -> None:
         self._agenerator = agenerator
         self._is_stream_consumed = False
 
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+    async def __aiter__(self) -> AsyncIterator[bytes]:
         if self._is_stream_consumed:
             raise StreamConsumed()
 
@@ -66,8 +75,8 @@ class AsyncGeneratorStream:
 
 
 def encode_content(
-    content: typing.Union[str, bytes, ByteStream]
-) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
+    content: Union[str, bytes, ByteStream]
+) -> Tuple[Dict[str, str], ByteStream]:
     if isinstance(content, (str, bytes)):
         body = content.encode("utf-8") if isinstance(content, str) else content
         content_length = str(len(body))
@@ -75,7 +84,7 @@ def encode_content(
         stream = PlainByteStream(body)
         return headers, stream
 
-    elif isinstance(content, (typing.Iterable, typing.AsyncIterable)):
+    elif isinstance(content, (Iterable, AsyncIterable)):
         headers = {"Transfer-Encoding": "chunked"}
 
         # Generators should be wrapped in GeneratorStream/AsyncGeneratorStream
@@ -96,7 +105,7 @@ def encode_content(
 
 def encode_urlencoded_data(
     data: dict,
-) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
+) -> Tuple[Dict[str, str], ByteStream]:
     body = urlencode(data, doseq=True).encode("utf-8")
     content_length = str(len(body))
     content_type = "application/x-www-form-urlencoded"
@@ -106,13 +115,29 @@ def encode_urlencoded_data(
 
 def encode_multipart_data(
     data: dict, files: RequestFiles, boundary: bytes = None
-) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
+) -> Tuple[Dict[str, str], ByteStream]:
     stream = MultipartStream(data=data, files=files, boundary=boundary)
     headers = stream.get_headers()
     return headers, stream
 
 
-def encode_json(json: typing.Any) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
+def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
+    body = text.encode("utf-8")
+    content_length = str(len(body))
+    content_type = "text/plain; charset=utf-8"
+    headers = {"Content-Length": content_length, "Content-Type": content_type}
+    return headers, PlainByteStream(body)
+
+
+def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]:
+    body = html.encode("utf-8")
+    content_length = str(len(body))
+    content_type = "text/html; charset=utf-8"
+    headers = {"Content-Length": content_length, "Content-Type": content_type}
+    return headers, PlainByteStream(body)
+
+
+def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]:
     body = json_dumps(json).encode("utf-8")
     content_length = str(len(body))
     content_type = "application/json"
@@ -124,9 +149,9 @@ def encode_request(
     content: RequestContent = None,
     data: RequestData = None,
     files: RequestFiles = None,
-    json: typing.Any = None,
+    json: Any = None,
     boundary: bytes = None,
-) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
+) -> Tuple[Dict[str, str], ByteStream]:
     """
     Handles encoding the given `content`, `data`, `files`, and `json`,
     returning a two-tuple of (<headers>, <stream>).
@@ -155,12 +180,21 @@ def encode_request(
 
 def encode_response(
     content: ResponseContent = None,
-) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
+    text: str = None,
+    html: str = None,
+    json: Any = None,
+) -> Tuple[Dict[str, str], ByteStream]:
     """
     Handles encoding the given `content`, returning a two-tuple of
     (<headers>, <stream>).
     """
     if content is not None:
         return encode_content(content)
+    elif text is not None:
+        return encode_text(text)
+    elif html is not None:
+        return encode_html(html)
+    elif json is not None:
+        return encode_json(json)
 
     return {}, PlainByteStream(b"")
index 9dd3e0b0d4912e455ad28d8763bb12969eaa8096..03a08075f3d240e682e258c9ce43c676d79cd5b6 100644 (file)
@@ -704,11 +704,14 @@ class Response:
         self,
         status_code: int,
         *,
-        request: Request = None,
-        http_version: str = None,
         headers: HeaderTypes = None,
         content: ResponseContent = None,
+        text: str = None,
+        html: str = None,
+        json: typing.Any = None,
         stream: ByteStream = None,
+        http_version: str = None,
+        request: Request = None,
         history: typing.List["Response"] = None,
         on_close: typing.Callable = None,
     ):
@@ -740,7 +743,7 @@ class Response:
             # from the transport API.
             self.stream = stream
         else:
-            headers, stream = encode_response(content)
+            headers, stream = encode_response(content, text, html, json)
             self._prepare(headers)
             self.stream = stream
             if content is None or isinstance(content, bytes):
index e71fe906b0b15a6ddba0867c0b385d757f925257..59777b8c696ffc9a931cdcf660a3a66adfcb1734 100644 (file)
@@ -5,7 +5,6 @@ Unit tests for auth classes also exist in tests/test_auth.py
 """
 import asyncio
 import hashlib
-import json
 import os
 import threading
 import typing
@@ -27,8 +26,7 @@ class App:
     def __call__(self, request: httpx.Request) -> httpx.Response:
         headers = {"www-authenticate": self.auth_header} if self.auth_header else {}
         data = {"auth": request.headers.get("Authorization")}
-        content = json.dumps(data).encode("utf-8")
-        return httpx.Response(self.status_code, headers=headers, content=content)
+        return httpx.Response(self.status_code, headers=headers, json=data)
 
 
 class DigestApp:
@@ -50,8 +48,7 @@ class DigestApp:
             return self.challenge_send(request)
 
         data = {"auth": request.headers.get("Authorization")}
-        content = json.dumps(data).encode("utf-8")
-        return httpx.Response(200, content=content)
+        return httpx.Response(200, json=data)
 
     def challenge_send(self, request: httpx.Request) -> httpx.Response:
         self._response_count += 1
index af614effb69d79b8571ebd8542c13fcbf0a05a99..feb26ac4365087f863287fcaee214c77bc2925fb 100644 (file)
@@ -1,4 +1,3 @@
-import json
 from http.cookiejar import Cookie, CookieJar
 
 import httpx
@@ -8,8 +7,7 @@ from tests.utils import MockTransport
 def get_and_set_cookies(request: httpx.Request) -> httpx.Response:
     if request.url.path == "/echo_cookies":
         data = {"cookies": request.headers.get("cookie")}
-        content = json.dumps(data).encode("utf-8")
-        return httpx.Response(200, content=content)
+        return httpx.Response(200, json=data)
     elif request.url.path == "/set_cookie":
         return httpx.Response(200, headers={"set-cookie": "example-name=example-value"})
     else:
index d968616f4ec08e342f0a2086d90727601a15051d..556cd1df141d0e4de63b046b9bd46cc711cfb5de 100755 (executable)
@@ -1,7 +1,5 @@
 #!/usr/bin/env python3
 
-import json
-
 import pytest
 
 import httpx
@@ -10,8 +8,7 @@ from tests.utils import MockTransport
 
 def echo_headers(request: httpx.Request) -> httpx.Response:
     data = {"headers": dict(request.headers)}
-    content = json.dumps(data).encode("utf-8")
-    return httpx.Response(200, content=content)
+    return httpx.Response(200, json=data)
 
 
 def test_client_header():
index 39731d5bb0210c8d7302f404e750b2e9d03f0a55..6d3a9d5b5dc7875492734eeaf732626d23efe5a6 100644 (file)
@@ -3,7 +3,7 @@ from tests.utils import MockTransport
 
 
 def hello_world(request: httpx.Request) -> httpx.Response:
-    return httpx.Response(200, content=b"Hello, world")
+    return httpx.Response(200, text="Hello, world")
 
 
 def test_client_queryparams():
index 0d51717a0507a82a2f7929666a933ddd98131051..f32512bbf9a7f81ee6f6c6e17214653ee25b89c6 100644 (file)
@@ -1,5 +1,3 @@
-import json
-
 import httpcore
 import pytest
 
@@ -78,8 +76,11 @@ def redirects(request: httpx.Request) -> httpx.Response:
 
     elif request.url.path == "/cross_domain_target":
         status_code = httpx.codes.OK
-        content = json.dumps({"headers": dict(request.headers)}).encode("utf-8")
-        return httpx.Response(status_code, content=content)
+        data = {
+            "body": request.content.decode("ascii"),
+            "headers": dict(request.headers),
+        }
+        return httpx.Response(status_code, json=data)
 
     elif request.url.path == "/redirect_body":
         status_code = httpx.codes.PERMANENT_REDIRECT
@@ -92,10 +93,11 @@ def redirects(request: httpx.Request) -> httpx.Response:
         return httpx.Response(status_code, headers=headers)
 
     elif request.url.path == "/redirect_body_target":
-        content = json.dumps(
-            {"body": request.content.decode("ascii"), "headers": dict(request.headers)}
-        ).encode("utf-8")
-        return httpx.Response(200, content=content)
+        data = {
+            "body": request.content.decode("ascii"),
+            "headers": dict(request.headers),
+        }
+        return httpx.Response(200, json=data)
 
     elif request.url.path == "/cross_subdomain":
         if request.headers["Host"] != "www.example.org":
@@ -103,7 +105,7 @@ def redirects(request: httpx.Request) -> httpx.Response:
             headers = {"location": "https://www.example.org/cross_subdomain"}
             return httpx.Response(status_code, headers=headers)
         else:
-            return httpx.Response(200, content=b"Hello, world!")
+            return httpx.Response(200, text="Hello, world!")
 
     elif request.url.path == "/redirect_custom_scheme":
         status_code = httpx.codes.MOVED_PERMANENTLY
@@ -113,7 +115,7 @@ def redirects(request: httpx.Request) -> httpx.Response:
     if request.method == "HEAD":
         return httpx.Response(200)
 
-    return httpx.Response(200, content=b"Hello, world!")
+    return httpx.Response(200, html="<html><body>Hello, world!</body></html>")
 
 
 def test_no_redirect():
index e1568b324dd8f74105962bb3b8dd6ab0034441ad..2e38381185d532e4d436b59c835c92c96f190c05 100644 (file)
@@ -38,6 +38,48 @@ def test_response():
     assert not response.is_error
 
 
+def test_response_text():
+    response = httpx.Response(200, text="Hello, world!")
+
+    assert response.status_code == 200
+    assert response.reason_phrase == "OK"
+    assert response.text == "Hello, world!"
+    assert response.headers == httpx.Headers(
+        {
+            "Content-Length": "13",
+            "Content-Type": "text/plain; charset=utf-8",
+        }
+    )
+
+
+def test_response_html():
+    response = httpx.Response(200, html="<html><body>Hello, world!</html></body>")
+
+    assert response.status_code == 200
+    assert response.reason_phrase == "OK"
+    assert response.text == "<html><body>Hello, world!</html></body>"
+    assert response.headers == httpx.Headers(
+        {
+            "Content-Length": "39",
+            "Content-Type": "text/html; charset=utf-8",
+        }
+    )
+
+
+def test_response_json():
+    response = httpx.Response(200, json={"hello": "world"})
+
+    assert response.status_code == 200
+    assert response.reason_phrase == "OK"
+    assert response.json() == {"hello": "world"}
+    assert response.headers == httpx.Headers(
+        {
+            "Content-Length": "18",
+            "Content-Type": "application/json",
+        }
+    )
+
+
 def test_raise_for_status():
     request = httpx.Request("GET", "https://example.org")