]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Make the response's request parameter optional (#1238)
authortbascoul <tbascoul+github@gmail.com>
Tue, 1 Sep 2020 14:14:57 +0000 (16:14 +0200)
committerGitHub <noreply@github.com>
Tue, 1 Sep 2020 14:14:57 +0000 (15:14 +0100)
* Make the response's request parameter optional

* Fix _models coverage

* Move DecodingError in _models

* Update httpx/_models.py

* Update _models.py

* Update test_responses.py

* Update test_responses.py

Co-authored-by: Tom Christie <tom@tomchristie.com>
httpx/_decoders.py
httpx/_models.py
tests/models/test_responses.py
tests/test_decoders.py

index fa9f8124a234f48def72691ca0b072a72c7cc28f..6c0c4492f97fed5f3464773bad5d1f7a4a2ff60e 100644 (file)
@@ -9,21 +9,13 @@ import zlib
 
 import chardet
 
-from ._exceptions import DecodingError
-
 try:
     import brotli
 except ImportError:  # pragma: nocover
     brotli = None
 
-if typing.TYPE_CHECKING:  # pragma: no cover
-    from ._models import Request
-
 
 class Decoder:
-    def __init__(self, request: "Request") -> None:
-        self.request = request
-
     def decode(self, data: bytes) -> bytes:
         raise NotImplementedError()  # pragma: nocover
 
@@ -50,8 +42,7 @@ class DeflateDecoder(Decoder):
     See: https://stackoverflow.com/questions/1838699
     """
 
-    def __init__(self, request: "Request") -> None:
-        self.request = request
+    def __init__(self) -> None:
         self.first_attempt = True
         self.decompressor = zlib.decompressobj()
 
@@ -64,13 +55,13 @@ class DeflateDecoder(Decoder):
             if was_first_attempt:
                 self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
                 return self.decode(data)
-            raise DecodingError(message=str(exc), request=self.request)
+            raise ValueError(str(exc))
 
     def flush(self) -> bytes:
         try:
             return self.decompressor.flush()
         except zlib.error as exc:  # pragma: nocover
-            raise DecodingError(message=str(exc), request=self.request)
+            raise ValueError(str(exc))
 
 
 class GZipDecoder(Decoder):
@@ -80,21 +71,20 @@ class GZipDecoder(Decoder):
     See: https://stackoverflow.com/questions/1838699
     """
 
-    def __init__(self, request: "Request") -> None:
-        self.request = request
+    def __init__(self) -> None:
         self.decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16)
 
     def decode(self, data: bytes) -> bytes:
         try:
             return self.decompressor.decompress(data)
         except zlib.error as exc:
-            raise DecodingError(message=str(exc), request=self.request)
+            raise ValueError(str(exc))
 
     def flush(self) -> bytes:
         try:
             return self.decompressor.flush()
         except zlib.error as exc:  # pragma: nocover
-            raise DecodingError(message=str(exc), request=self.request)
+            raise ValueError(str(exc))
 
 
 class BrotliDecoder(Decoder):
@@ -107,7 +97,7 @@ class BrotliDecoder(Decoder):
     name. The top branches are for 'brotlipy' and bottom branches for 'Brotli'
     """
 
-    def __init__(self, request: "Request") -> None:
+    def __init__(self) -> None:
         if brotli is None:  # pragma: nocover
             raise ImportError(
                 "Using 'BrotliDecoder', but the 'brotlipy' or 'brotli' library "
@@ -115,7 +105,6 @@ class BrotliDecoder(Decoder):
                 "Make sure to install httpx using `pip install httpx[brotli]`."
             ) from None
 
-        self.request = request
         self.decompressor = brotli.Decompressor()
         self.seen_data = False
         if hasattr(self.decompressor, "decompress"):
@@ -130,7 +119,7 @@ class BrotliDecoder(Decoder):
         try:
             return self._decompress(data)
         except brotli.error as exc:
-            raise DecodingError(message=str(exc), request=self.request)
+            raise ValueError(str(exc))
 
     def flush(self) -> bytes:
         if not self.seen_data:
@@ -140,7 +129,7 @@ class BrotliDecoder(Decoder):
                 self.decompressor.finish()
             return b""
         except brotli.error as exc:  # pragma: nocover
-            raise DecodingError(message=str(exc), request=self.request)
+            raise ValueError(str(exc))
 
 
 class MultiDecoder(Decoder):
@@ -173,8 +162,7 @@ class TextDecoder:
     Handles incrementally decoding bytes into text
     """
 
-    def __init__(self, request: "Request", encoding: typing.Optional[str] = None):
-        self.request = request
+    def __init__(self, encoding: typing.Optional[str] = None):
         self.decoder: typing.Optional[codecs.IncrementalDecoder] = (
             None if encoding is None else codecs.getincrementaldecoder(encoding)()
         )
@@ -209,7 +197,7 @@ class TextDecoder:
 
             return text
         except UnicodeDecodeError as exc:  # pragma: nocover
-            raise DecodingError(message=str(exc), request=self.request)
+            raise ValueError(str(exc))
 
     def flush(self) -> str:
         try:
@@ -222,14 +210,13 @@ class TextDecoder:
 
             return self.decoder.decode(b"", True)
         except UnicodeDecodeError as exc:  # pragma: nocover
-            raise DecodingError(message=str(exc), request=self.request)
+            raise ValueError(str(exc))
 
     def _detector_result(self) -> str:
         self.detector.close()
         result = self.detector.result["encoding"]
         if not result:  # pragma: nocover
-            message = "Unable to determine encoding of content"
-            raise DecodingError(message, request=self.request)
+            raise ValueError("Unable to determine encoding of content")
 
         return result
 
index 67cebeb09115ba68eef5cc51471ba74991755197..4a4026326646f0e187e6e0a56951c42ea251588a 100644 (file)
@@ -1,4 +1,5 @@
 import cgi
+import contextlib
 import datetime
 import email.message
 import json as jsonlib
@@ -26,6 +27,7 @@ from ._decoders import (
 from ._exceptions import (
     HTTPCORE_EXC_MAP,
     CookieConflict,
+    DecodingError,
     HTTPStatusError,
     InvalidURL,
     NotRedirectResponse,
@@ -689,7 +691,7 @@ class Response:
         self,
         status_code: int,
         *,
-        request: Request,
+        request: Request = None,
         http_version: str = None,
         headers: HeaderTypes = None,
         stream: ContentStream = None,
@@ -700,7 +702,8 @@ class Response:
         self.http_version = http_version
         self.headers = Headers(headers)
 
-        self.request = request
+        self._request: typing.Optional[Request] = request
+
         self.call_next: typing.Optional[typing.Callable] = None
 
         self.history = [] if history is None else list(history)
@@ -726,6 +729,21 @@ class Response:
             )
         return self._elapsed
 
+    @property
+    def request(self) -> Request:
+        """
+        Returns the request instance associated to the current response.
+        """
+        if self._request is None:
+            raise RuntimeError(
+                "The request instance has not been set on this response."
+            )
+        return self._request
+
+    @request.setter
+    def request(self, value: Request) -> None:
+        self._request = value
+
     @property
     def reason_phrase(self) -> str:
         return codes.get_reason_phrase(self.status_code)
@@ -811,7 +829,7 @@ class Response:
                 value = value.strip().lower()
                 try:
                     decoder_cls = SUPPORTED_DECODERS[value]
-                    decoders.append(decoder_cls(request=self.request))
+                    decoders.append(decoder_cls())
                 except KeyError:
                     continue
 
@@ -820,7 +838,7 @@ class Response:
             elif len(decoders) > 1:
                 self._decoder = MultiDecoder(children=decoders)
             else:
-                self._decoder = IdentityDecoder(request=self.request)
+                self._decoder = IdentityDecoder()
 
         return self._decoder
 
@@ -841,12 +859,19 @@ class Response:
             "For more information check: https://httpstatuses.com/{0.status_code}"
         )
 
+        request = self._request
+        if request is None:
+            raise RuntimeError(
+                "Cannot call `raise_for_status` as the request "
+                "instance has not been set on this response."
+            )
+
         if codes.is_client_error(self.status_code):
             message = message.format(self, error_type="Client Error")
-            raise HTTPStatusError(message, request=self.request, response=self)
+            raise HTTPStatusError(message, request=request, response=self)
         elif codes.is_server_error(self.status_code):
             message = message.format(self, error_type="Server Error")
-            raise HTTPStatusError(message, request=self.request, response=self)
+            raise HTTPStatusError(message, request=request, response=self)
 
     def json(self, **kwargs: typing.Any) -> typing.Any:
         if self.charset_encoding is None and self.content and len(self.content) > 3:
@@ -882,6 +907,17 @@ class Response:
     def __repr__(self) -> str:
         return f"<Response [{self.status_code} {self.reason_phrase}]>"
 
+    @contextlib.contextmanager
+    def _wrap_decoder_errors(self) -> typing.Iterator[None]:
+        # If the response has an associated request instance, we want decoding
+        # errors to be raised as proper `httpx.DecodingError` exceptions.
+        try:
+            yield
+        except ValueError as exc:
+            if self._request is None:
+                raise exc
+            raise DecodingError(message=str(exc), request=self.request) from exc
+
     def read(self) -> bytes:
         """
         Read and return the response content.
@@ -898,9 +934,10 @@ class Response:
         if hasattr(self, "_content"):
             yield self._content
         else:
-            for chunk in self.iter_raw():
-                yield self.decoder.decode(chunk)
-            yield self.decoder.flush()
+            with self._wrap_decoder_errors():
+                for chunk in self.iter_raw():
+                    yield self.decoder.decode(chunk)
+                yield self.decoder.flush()
 
     def iter_text(self) -> typing.Iterator[str]:
         """
@@ -908,18 +945,20 @@ class Response:
         that handles both gzip, deflate, etc but also detects the content's
         string encoding.
         """
-        decoder = TextDecoder(request=self.request, encoding=self.charset_encoding)
-        for chunk in self.iter_bytes():
-            yield decoder.decode(chunk)
-        yield decoder.flush()
+        decoder = TextDecoder(encoding=self.charset_encoding)
+        with self._wrap_decoder_errors():
+            for chunk in self.iter_bytes():
+                yield decoder.decode(chunk)
+            yield decoder.flush()
 
     def iter_lines(self) -> typing.Iterator[str]:
         decoder = LineDecoder()
-        for text in self.iter_text():
-            for line in decoder.decode(text):
+        with self._wrap_decoder_errors():
+            for text in self.iter_text():
+                for line in decoder.decode(text):
+                    yield line
+            for line in decoder.flush():
                 yield line
-        for line in decoder.flush():
-            yield line
 
     def iter_raw(self) -> typing.Iterator[bytes]:
         """
@@ -931,7 +970,7 @@ class Response:
             raise ResponseClosed()
 
         self.is_stream_consumed = True
-        with map_exceptions(HTTPCORE_EXC_MAP, request=self.request):
+        with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
             for part in self._raw_stream:
                 yield part
         self.close()
@@ -956,7 +995,8 @@ class Response:
         """
         if not self.is_closed:
             self.is_closed = True
-            self._elapsed = self.request.timer.elapsed
+            if self._request is not None:
+                self._elapsed = self.request.timer.elapsed
             self._raw_stream.close()
 
     async def aread(self) -> bytes:
@@ -975,9 +1015,10 @@ class Response:
         if hasattr(self, "_content"):
             yield self._content
         else:
-            async for chunk in self.aiter_raw():
-                yield self.decoder.decode(chunk)
-            yield self.decoder.flush()
+            with self._wrap_decoder_errors():
+                async for chunk in self.aiter_raw():
+                    yield self.decoder.decode(chunk)
+                yield self.decoder.flush()
 
     async def aiter_text(self) -> typing.AsyncIterator[str]:
         """
@@ -985,18 +1026,20 @@ class Response:
         that handles both gzip, deflate, etc but also detects the content's
         string encoding.
         """
-        decoder = TextDecoder(request=self.request, encoding=self.charset_encoding)
-        async for chunk in self.aiter_bytes():
-            yield decoder.decode(chunk)
-        yield decoder.flush()
+        decoder = TextDecoder(encoding=self.charset_encoding)
+        with self._wrap_decoder_errors():
+            async for chunk in self.aiter_bytes():
+                yield decoder.decode(chunk)
+            yield decoder.flush()
 
     async def aiter_lines(self) -> typing.AsyncIterator[str]:
         decoder = LineDecoder()
-        async for text in self.aiter_text():
-            for line in decoder.decode(text):
+        with self._wrap_decoder_errors():
+            async for text in self.aiter_text():
+                for line in decoder.decode(text):
+                    yield line
+            for line in decoder.flush():
                 yield line
-        for line in decoder.flush():
-            yield line
 
     async def aiter_raw(self) -> typing.AsyncIterator[bytes]:
         """
@@ -1008,7 +1051,7 @@ class Response:
             raise ResponseClosed()
 
         self.is_stream_consumed = True
-        with map_exceptions(HTTPCORE_EXC_MAP, request=self.request):
+        with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
             async for part in self._raw_stream:
                 yield part
         await self.aclose()
@@ -1032,7 +1075,8 @@ class Response:
         """
         if not self.is_closed:
             self.is_closed = True
-            self._elapsed = self.request.timer.elapsed
+            if self._request is not None:
+                self._elapsed = self.request.timer.elapsed
             await self._raw_stream.aclose()
 
 
index 878bef072d49da7f1c92bafe94d9a2649c7772d4..e9fbeca22d0770f02f6cec6d0698d538c28728fd 100644 (file)
@@ -2,6 +2,7 @@ import datetime
 import json
 from unittest import mock
 
+import brotli
 import pytest
 
 import httpx
@@ -31,6 +32,28 @@ def test_response():
     assert not response.is_error
 
 
+def test_raise_for_status():
+    # 2xx status codes are not an error.
+    response = httpx.Response(200, request=REQUEST)
+    response.raise_for_status()
+
+    # 4xx status codes are a client error.
+    response = httpx.Response(403, request=REQUEST)
+    with pytest.raises(httpx.HTTPStatusError):
+        response.raise_for_status()
+
+    # 5xx status codes are a server error.
+    response = httpx.Response(500, request=REQUEST)
+    with pytest.raises(httpx.HTTPStatusError):
+        response.raise_for_status()
+
+    # Calling .raise_for_status without setting a request instance is
+    # not valid. Should raise a runtime error.
+    response = httpx.Response(200)
+    with pytest.raises(RuntimeError):
+        response.raise_for_status()
+
+
 def test_response_repr():
     response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
     assert repr(response) == "<Response [200 OK]>"
@@ -372,7 +395,18 @@ def test_json_without_specified_encoding_decode_error():
         response = httpx.Response(
             200, content=content, headers=headers, request=REQUEST
         )
-        with pytest.raises(json.JSONDecodeError):
+        with pytest.raises(json.decoder.JSONDecodeError):
+            response.json()
+
+
+def test_json_without_specified_encoding_value_error():
+    data = {"greeting": "hello", "recipient": "world"}
+    content = json.dumps(data).encode("utf-32-be")
+    headers = {"Content-Type": "application/json"}
+    # force incorrect guess from `guess_json_utf` to trigger error
+    with mock.patch("httpx._models.guess_json_utf", return_value="utf-32"):
+        response = httpx.Response(200, content=content, headers=headers)
+        with pytest.raises(ValueError):
             response.json()
 
 
@@ -395,3 +429,45 @@ def test_json_without_specified_encoding_decode_error():
 def test_link_headers(headers, expected):
     response = httpx.Response(200, content=None, headers=headers, request=REQUEST)
     assert response.links == expected
+
+
+@pytest.mark.parametrize("header_value", (b"deflate", b"gzip", b"br"))
+def test_decode_error_with_request(header_value):
+    headers = [(b"Content-Encoding", header_value)]
+    body = b"test 123"
+    compressed_body = brotli.compress(body)[3:]
+    with pytest.raises(httpx.DecodingError):
+        httpx.Response(200, headers=headers, content=compressed_body, request=REQUEST)
+
+
+@pytest.mark.parametrize("header_value", (b"deflate", b"gzip", b"br"))
+def test_value_error_without_request(header_value):
+    headers = [(b"Content-Encoding", header_value)]
+    body = b"test 123"
+    compressed_body = brotli.compress(body)[3:]
+    with pytest.raises(ValueError):
+        httpx.Response(200, headers=headers, content=compressed_body)
+
+
+def test_response_with_unset_request():
+    response = httpx.Response(200, content=b"Hello, world!")
+
+    assert response.status_code == 200
+    assert response.reason_phrase == "OK"
+    assert response.text == "Hello, world!"
+    assert not response.is_error
+
+
+def test_set_request_after_init():
+    response = httpx.Response(200, content=b"Hello, world!")
+
+    response.request = REQUEST
+
+    assert response.request == REQUEST
+
+
+def test_cannot_access_unset_request():
+    response = httpx.Response(200, content=b"Hello, world!")
+
+    with pytest.raises(RuntimeError):
+        assert response.request is not None
index ec01d41e4cf025739974746eeef84d0b9ba84349..abd478e400277f9ab300a08c3da040d32b669b79 100644 (file)
@@ -135,9 +135,9 @@ def test_empty_content(header_value):
     "decoder", (BrotliDecoder, DeflateDecoder, GZipDecoder, IdentityDecoder)
 )
 def test_decoders_empty_cases(decoder):
-    request = httpx.Request(method="GET", url="https://www.example.com")
-    instance = decoder(request)
-    assert instance.decode(b"") == b""
+    response = httpx.Response(content=b"", status_code=200)
+    instance = decoder()
+    assert instance.decode(response.content) == b""
     assert instance.flush() == b""
 
 
@@ -207,12 +207,10 @@ async def test_text_decoder_known_encoding():
 
 
 def test_text_decoder_empty_cases():
-    request = httpx.Request(method="GET", url="https://www.example.com")
-
-    decoder = TextDecoder(request=request)
+    decoder = TextDecoder()
     assert decoder.flush() == ""
 
-    decoder = TextDecoder(request=request)
+    decoder = TextDecoder()
     assert decoder.decode(b"") == ""
     assert decoder.flush() == ""