]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Make Request and Response picklable (#1579)
authorHannes Ljungberg <hannes@5monkeys.se>
Wed, 21 Apr 2021 10:11:00 +0000 (12:11 +0200)
committerGitHub <noreply@github.com>
Wed, 21 Apr 2021 10:11:00 +0000 (11:11 +0100)
* Make Request and Response picklable

* fixup! Make Request and Response picklable

* Apply suggestions from code review

* Apply suggestions from code review

* Update tests/models/test_requests.py

Co-authored-by: Tom Christie <tom@tomchristie.com>
httpx/_content.py
httpx/_models.py
tests/models/test_requests.py
tests/models/test_responses.py

index e4a906520b87ea9c345909c71f3974d2aa5d246a..9c7c1ff2252f0211bcf5f18ee6de65f6af4db0d1 100644 (file)
@@ -13,7 +13,7 @@ from typing import (
 )
 from urllib.parse import urlencode
 
-from ._exceptions import StreamConsumed
+from ._exceptions import StreamClosed, StreamConsumed
 from ._multipart import MultipartStream
 from ._transports.base import AsyncByteStream, SyncByteStream
 from ._types import RequestContent, RequestData, RequestFiles, ResponseContent
@@ -61,6 +61,21 @@ class AsyncIteratorByteStream(AsyncByteStream):
             yield part
 
 
+class UnattachedStream(AsyncByteStream, SyncByteStream):
+    """
+    If a request or response is serialized using pickle, then it is no longer
+    attached to a stream for I/O purposes. Any stream operations should result
+    in `httpx.StreamClosed`.
+    """
+
+    def __iter__(self) -> Iterator[bytes]:
+        raise StreamClosed()
+
+    async def __aiter__(self) -> AsyncIterator[bytes]:
+        raise StreamClosed()
+        yield b""  # pragma: nocover
+
+
 def encode_content(
     content: Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
 ) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
index 2e4a3b6c8ad679ed80402a9ed21f7b627604d8d3..357baaca154bb6f23f61c22cae572c1bbc498594 100644 (file)
@@ -11,7 +11,7 @@ from urllib.parse import parse_qsl, quote, unquote, urlencode
 import rfc3986
 import rfc3986.exceptions
 
-from ._content import ByteStream, encode_request, encode_response
+from ._content import ByteStream, UnattachedStream, encode_request, encode_response
 from ._decoders import (
     SUPPORTED_DECODERS,
     ByteChunker,
@@ -898,6 +898,18 @@ class Request:
         url = str(self.url)
         return f"<{class_name}({self.method!r}, {url!r})>"
 
+    def __getstate__(self) -> typing.Dict[str, typing.Any]:
+        return {
+            name: value
+            for name, value in self.__dict__.items()
+            if name not in ["stream"]
+        }
+
+    def __setstate__(self, state: typing.Dict[str, typing.Any]) -> None:
+        for name, value in state.items():
+            setattr(self, name, value)
+        self.stream = UnattachedStream()
+
 
 class Response:
     def __init__(
@@ -1156,6 +1168,19 @@ class Response:
     def __repr__(self) -> str:
         return f"<Response [{self.status_code} {self.reason_phrase}]>"
 
+    def __getstate__(self) -> typing.Dict[str, typing.Any]:
+        return {
+            name: value
+            for name, value in self.__dict__.items()
+            if name not in ["stream", "is_closed", "_decoder"]
+        }
+
+    def __setstate__(self, state: typing.Dict[str, typing.Any]) -> None:
+        for name, value in state.items():
+            setattr(self, name, value)
+        self.is_closed = True
+        self.stream = UnattachedStream()
+
     def read(self) -> bytes:
         """
         Read and return the response content.
index cfc53e0b593751f04dfe4eadcab77045ed8a0cca..a93e8994584d64454223797d569f83f05868c0a8 100644 (file)
@@ -1,3 +1,4 @@
+import pickle
 import typing
 
 import pytest
@@ -174,3 +175,54 @@ def test_url():
     assert request.url.port is None
     assert request.url.path == "/abc"
     assert request.url.raw_path == b"/abc?foo=bar"
+
+
+def test_request_picklable():
+    request = httpx.Request("POST", "http://example.org", json={"test": 123})
+    pickle_request = pickle.loads(pickle.dumps(request))
+    assert pickle_request.method == "POST"
+    assert pickle_request.url.path == "/"
+    assert pickle_request.headers["Content-Type"] == "application/json"
+    assert pickle_request.content == b'{"test": 123}'
+    assert pickle_request.stream is not None
+    assert request.headers == {
+        "Host": "example.org",
+        "Content-Type": "application/json",
+        "content-length": "13",
+    }
+
+
+@pytest.mark.asyncio
+async def test_request_async_streaming_content_picklable():
+    async def streaming_body(data):
+        yield data
+
+    data = streaming_body(b"test 123")
+    request = httpx.Request("POST", "http://example.org", content=data)
+    pickle_request = pickle.loads(pickle.dumps(request))
+    with pytest.raises(httpx.RequestNotRead):
+        pickle_request.content
+    with pytest.raises(httpx.StreamClosed):
+        await pickle_request.aread()
+
+    request = httpx.Request("POST", "http://example.org", content=data)
+    await request.aread()
+    pickle_request = pickle.loads(pickle.dumps(request))
+    assert pickle_request.content == b"test 123"
+
+
+def test_request_generator_content_picklable():
+    def content():
+        yield b"test 123"  # pragma: nocover
+
+    request = httpx.Request("POST", "http://example.org", content=content())
+    pickle_request = pickle.loads(pickle.dumps(request))
+    with pytest.raises(httpx.RequestNotRead):
+        pickle_request.content
+    with pytest.raises(httpx.StreamClosed):
+        pickle_request.read()
+
+    request = httpx.Request("POST", "http://example.org", content=content())
+    request.read()
+    pickle_request = pickle.loads(pickle.dumps(request))
+    assert pickle_request.content == b"test 123"
index 78f5db0b705c96f52592086e8ec43d2a367d1291..5e2afc1bf300c9b7e6ae29af16c4a8b9d9cca611 100644 (file)
@@ -1,4 +1,5 @@
 import json
+import pickle
 from unittest import mock
 
 import brotli
@@ -853,3 +854,41 @@ def test_generator_with_content_length_header():
     headers = {"Content-Length": "8"}
     response = httpx.Response(200, content=content(), headers=headers)
     assert response.headers == {"Content-Length": "8"}
+
+
+def test_response_picklable():
+    response = httpx.Response(
+        200,
+        content=b"Hello, world!",
+        request=httpx.Request("GET", "https://example.org"),
+    )
+    pickle_response = pickle.loads(pickle.dumps(response))
+    assert pickle_response.is_closed is True
+    assert pickle_response.is_stream_consumed is True
+    assert pickle_response.next_request is None
+    assert pickle_response.stream is not None
+    assert pickle_response.content == b"Hello, world!"
+    assert pickle_response.status_code == 200
+    assert pickle_response.request.url == response.request.url
+    assert pickle_response.extensions == {}
+    assert pickle_response.history == []
+
+
+@pytest.mark.asyncio
+async def test_response_async_streaming_picklable():
+    response = httpx.Response(200, content=async_streaming_body())
+    pickle_response = pickle.loads(pickle.dumps(response))
+    with pytest.raises(httpx.ResponseNotRead):
+        pickle_response.content
+    with pytest.raises(httpx.StreamClosed):
+        await pickle_response.aread()
+    assert pickle_response.is_stream_consumed is False
+    assert pickle_response.num_bytes_downloaded == 0
+    assert pickle_response.headers == {"Transfer-Encoding": "chunked"}
+
+    response = httpx.Response(200, content=async_streaming_body())
+    await response.aread()
+    pickle_response = pickle.loads(pickle.dumps(response))
+    assert pickle_response.is_stream_consumed is True
+    assert pickle_response.content == b"Hello, world!"
+    assert pickle_response.num_bytes_downloaded == 13