]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Added base class HTTPError with request/response attribute (#162)
authorhalbow <39669025+halbow@users.noreply.github.com>
Tue, 6 Aug 2019 13:20:48 +0000 (15:20 +0200)
committerSeth Michael Larson <sethmichaellarson@gmail.com>
Tue, 6 Aug 2019 13:20:48 +0000 (08:20 -0500)
httpx/client.py
httpx/concurrency.py
httpx/dispatch/http11.py
httpx/dispatch/http2.py
httpx/exceptions.py
httpx/interfaces.py
httpx/models.py
tests/client/test_async_client.py
tests/client/test_client.py
tests/dispatch/utils.py

index 5e814b38ee48a755ec3bce8becc76ff637bb9ffe..e22d1a9fb92cf2cdcb70e200ae9d17b2cf11371d 100644 (file)
@@ -24,6 +24,7 @@ from .exceptions import (
     RedirectBodyUnavailable,
     RedirectLoop,
     TooManyRedirects,
+    HTTPError,
 )
 from .interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher
 from .models import (
@@ -82,13 +83,13 @@ class BaseClient:
                 )
 
         if dispatch is None:
-            async_dispatch = ConnectionPool(
+            async_dispatch: AsyncDispatcher = ConnectionPool(
                 verify=verify,
                 cert=cert,
                 timeout=timeout,
                 pool_limits=pool_limits,
                 backend=backend,
-            )  # type: AsyncDispatcher
+            )
         elif isinstance(dispatch, Dispatcher):
             async_dispatch = ThreadedDispatcher(dispatch, backend)
         else:
@@ -167,13 +168,18 @@ class BaseClient:
                 auth = HTTPBasicAuth(username=auth[0], password=auth[1])
             request = auth(request)
 
-        response = await self.send_handling_redirects(
-            request,
-            verify=verify,
-            cert=cert,
-            timeout=timeout,
-            allow_redirects=allow_redirects,
-        )
+        try:
+            response = await self.send_handling_redirects(
+                request,
+                verify=verify,
+                cert=cert,
+                timeout=timeout,
+                allow_redirects=allow_redirects,
+            )
+        except HTTPError as exc:
+            # Add the original request to any HTTPError
+            exc.request = request
+            raise
 
         if not stream:
             try:
@@ -200,19 +206,20 @@ class BaseClient:
             # We perform these checks here, so that calls to `response.next()`
             # will raise redirect errors if appropriate.
             if len(history) > self.max_redirects:
-                raise TooManyRedirects()
+                raise TooManyRedirects(response=history[-1])
             if request.url in [response.url for response in history]:
-                raise RedirectLoop()
+                raise RedirectLoop(response=history[-1])
 
             response = await self.dispatch.send(
                 request, verify=verify, cert=cert, timeout=timeout
             )
+
             should_close_response = True
             try:
                 assert isinstance(response, AsyncResponse)
                 response.history = list(history)
                 self.cookies.extract_cookies(response)
-                history = history + [response]
+                history.append(response)
 
                 if allow_redirects and response.is_redirect:
                     request = self.build_redirect_request(request, response)
@@ -249,7 +256,7 @@ class BaseClient:
         method = self.redirect_method(request, response)
         url = self.redirect_url(request, response)
         headers = self.redirect_headers(request, url)
-        content = self.redirect_content(request, method)
+        content = self.redirect_content(request, method, response)
         cookies = self.merge_cookies(request.cookies)
         return AsyncRequest(
             method=method, url=url, headers=headers, data=content, cookies=cookies
@@ -307,14 +314,16 @@ class BaseClient:
             del headers["Authorization"]
         return headers
 
-    def redirect_content(self, request: AsyncRequest, method: str) -> bytes:
+    def redirect_content(
+        self, request: AsyncRequest, method: str, response: AsyncResponse
+    ) -> bytes:
         """
         Return the body that should be used for the redirect request.
         """
         if method != request.method and method == "GET":
             return b""
         if request.is_streaming:
-            raise RedirectBodyUnavailable()
+            raise RedirectBodyUnavailable(response=response)
         return request.content
 
 
index 2ee45a85036ed20eac29e8f72def77045c99f3f1..f07eef4c64a6fe6ec3a91ff6d7d936a2a018fc90 100644 (file)
@@ -223,7 +223,7 @@ class AsyncioBackend(ConcurrencyBackend):
         writer = Writer(stream_writer=stream_writer, timeout=timeout)
         protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11
 
-        return (reader, writer, protocol)
+        return reader, writer, protocol
 
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
index 690db4bb1004e7947e0a22c194b807b732cd5db4..0f34191e975c7d5d7197035991112aefb707c61a 100644 (file)
@@ -131,7 +131,7 @@ class HTTP11Connection:
                 assert isinstance(event, h11.Response)
                 break
         http_version = "HTTP/%s" % event.http_version.decode("latin-1", errors="ignore")
-        return (http_version, event.status_code, event.headers)
+        return http_version, event.status_code, event.headers
 
     async def _receive_response_data(
         self, timeout: TimeoutConfig = None
index 331f82df3808d5fceaaf0bac02f206dd9563f514..980b07b25cc83d674c3b0fe34b08c30962860efe 100644 (file)
@@ -133,7 +133,7 @@ class HTTP2Connection:
                 status_code = int(v.decode("ascii", errors="ignore"))
             elif not k.startswith(b":"):
                 headers.append((k, v))
-        return (status_code, headers)
+        return status_code, headers
 
     async def body_iter(
         self, stream_id: int, timeout: TimeoutConfig = None
index 19af3e6be46d0a6e807d95aefd1a8b26ba80c747..21dc7f4de4b0fd48528f1e46bc36682f946b36a4 100644 (file)
@@ -1,7 +1,24 @@
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .models import BaseRequest, BaseResponse  # pragma: nocover
+
+
+class HTTPError(Exception):
+    """
+    Base class for Httpx exception
+    """
+
+    def __init__(self, request: 'BaseRequest' = None, response: 'BaseResponse' = None, *args) -> None:
+        self.response = response
+        self.request = request or getattr(self.response, "request", None)
+        super().__init__(*args)
+
+
 # Timeout exceptions...
 
 
-class Timeout(Exception):
+class Timeout(HTTPError):
     """
     A base class for all timeouts.
     """
@@ -34,19 +51,13 @@ class PoolTimeout(Timeout):
 # HTTP exceptions...
 
 
-class HttpError(Exception):
-    """
-    An HTTP error occurred.
-    """
-
-
-class ProtocolError(Exception):
+class ProtocolError(HTTPError):
     """
     Malformed HTTP.
     """
 
 
-class DecodingError(Exception):
+class DecodingError(HTTPError):
     """
     Decoding of the response failed.
     """
@@ -55,7 +66,7 @@ class DecodingError(Exception):
 # Redirect exceptions...
 
 
-class RedirectError(Exception):
+class RedirectError(HTTPError):
     """
     Base class for HTTP redirect errors.
     """
@@ -83,7 +94,7 @@ class RedirectLoop(RedirectError):
 # Stream exceptions...
 
 
-class StreamException(Exception):
+class StreamError(HTTPError):
     """
     The base class for stream exceptions.
 
@@ -92,21 +103,21 @@ class StreamException(Exception):
     """
 
 
-class StreamConsumed(StreamException):
+class StreamConsumed(StreamError):
     """
     Attempted to read or stream response content, but the content has already
     been streamed.
     """
 
 
-class ResponseNotRead(StreamException):
+class ResponseNotRead(StreamError):
     """
     Attempted to access response content, without having called `read()`
     after a streaming response.
     """
 
 
-class ResponseClosed(StreamException):
+class ResponseClosed(StreamError):
     """
     Attempted to read or stream response content, but the request has been
     closed.
@@ -116,13 +127,13 @@ class ResponseClosed(StreamException):
 # Other cases...
 
 
-class InvalidURL(Exception):
+class InvalidURL(HTTPError):
     """
     URL was missing a hostname, or was not one of HTTP/HTTPS.
     """
 
 
-class CookieConflict(Exception):
+class CookieConflict(HTTPError):
     """
     Attempted to lookup a cookie by name, but multiple cookies existed.
     """
index f058edeb6dcb48c0228aee8b06fb928f1f7ccdf6..0a17f99c15dfcd295a96195c47639ec2b49c472a 100644 (file)
@@ -26,7 +26,7 @@ class AsyncDispatcher:
     """
     Base class for async dispatcher classes, that handle sending the request.
 
-    Stubs out the interface, as well as providing a `.request()` convienence
+    Stubs out the interface, as well as providing a `.request()` convenience
     implementation, to make it easy to use or test stand-alone dispatchers,
     without requiring a complete `Client` instance.
     """
@@ -72,9 +72,9 @@ class AsyncDispatcher:
 
 class Dispatcher:
     """
-    Base class for syncronous dispatcher classes, that handle sending the request.
+    Base class for synchronous dispatcher classes, that handle sending the request.
 
-    Stubs out the interface, as well as providing a `.request()` convienence
+    Stubs out the interface, as well as providing a `.request()` convenience
     implementation, to make it easy to use or test stand-alone dispatchers,
     without requiring a complete `Client` instance.
     """
@@ -136,7 +136,7 @@ class BaseReader:
 
 class BaseWriter:
     """
-    A stream writer. Abstracts away any asyncio-specfic interfaces
+    A stream writer. Abstracts away any asyncio-specific interfaces
     into a more generic base class, that we can use with alternate
     backend, or for stand-alone test cases.
     """
@@ -155,7 +155,7 @@ class BasePoolSemaphore:
     """
     A semaphore for use with connection pooling.
 
-    Abstracts away any asyncio-specfic interfaces.
+    Abstracts away any asyncio-specific interfaces.
     """
 
     async def acquire(self) -> None:
index 5e0c827ed5f78d7e33f4bbeb1eb43519bf9334e2..4b9bdc5d95b3d8744e889a9ecb0b1d8331022b6e 100644 (file)
@@ -20,7 +20,7 @@ from .decoders import (
 )
 from .exceptions import (
     CookieConflict,
-    HttpError,
+    HTTPError,
     InvalidURL,
     ResponseClosed,
     ResponseNotRead,
@@ -528,10 +528,10 @@ class BaseRequest:
         return content, content_type
 
     def prepare(self) -> None:
-        content = getattr(self, "content", None)  # type: bytes
+        content: typing.Optional[bytes] = getattr(self, "content", None)
         is_streaming = getattr(self, "is_streaming", False)
 
-        auto_headers = []  # type: typing.List[typing.Tuple[bytes, bytes]]
+        auto_headers: typing.List[typing.Tuple[bytes, bytes]] = []
 
         has_host = "host" in self.headers
         has_user_agent = "user-agent" in self.headers
@@ -687,7 +687,7 @@ class BaseResponse:
 
         self.request = request
         self.on_close = on_close
-        self.next = None  # typing.Optional[typing.Callable]
+        self.next: typing.Optional[typing.Callable] = None
 
     @property
     def reason_phrase(self) -> str:
@@ -776,7 +776,7 @@ class BaseResponse:
         content, depending on the Content-Encoding used in the response.
         """
         if not hasattr(self, "_decoder"):
-            decoders = []  # type: typing.List[Decoder]
+            decoders: typing.List[Decoder] = []
             values = self.headers.getlist("content-encoding", split_commas=True)
             for value in values:
                 value = value.strip().lower()
@@ -811,9 +811,8 @@ class BaseResponse:
             message = message.format(self, error_type="Server Error")
         else:
             message = ""
-
         if message:
-            raise HttpError(message)
+            raise HTTPError(message, response=self)
 
     def json(self, **kwargs: typing.Any) -> typing.Union[dict, list]:
         if self.charset_encoding is None and self.content and len(self.content) > 3:
index 4f79af0471c10e9aa12bde3909696aedd7079eae..b037085f25fc38629c3ea1b8b808890f941b7bce 100644 (file)
@@ -72,8 +72,9 @@ async def test_raise_for_status(server):
             )
 
             if 400 <= status_code < 600:
-                with pytest.raises(httpx.exceptions.HttpError):
+                with pytest.raises(httpx.exceptions.HTTPError) as exc_info:
                     response.raise_for_status()
+                assert exc_info.value.response == response
             else:
                 assert response.raise_for_status() is None
 
index f85fe77f48008569aa1da1e4523e66ba90a3f55a..97ae0277dd5d6f462f191a1686eb6ddbbb75a199 100644 (file)
@@ -95,10 +95,10 @@ def test_raise_for_status(server):
             response = client.request(
                 "GET", "http://127.0.0.1:8000/status/{}".format(status_code)
             )
-
             if 400 <= status_code < 600:
-                with pytest.raises(httpx.exceptions.HttpError):
+                with pytest.raises(httpx.exceptions.HTTPError) as exc_info:
                     response.raise_for_status()
+                assert exc_info.value.response == response
             else:
                 assert response.raise_for_status() is None
 
index 8a62554b2583f672ec3c5c5cfca4c4cd1a809c2d..c92fa7a310b32bdd3039364be03fdc07251f55d7 100644 (file)
@@ -29,7 +29,7 @@ class MockHTTP2Backend(AsyncioBackend):
         timeout: TimeoutConfig,
     ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
         self.server = MockHTTP2Server(self.app)
-        return (self.server, self.server, Protocol.HTTP_2)
+        return self.server, self.server, Protocol.HTTP_2
 
 
 class MockHTTP2Server(BaseReader, BaseWriter):