]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Transport API (#1522)
authorTom Christie <tom@tomchristie.com>
Wed, 24 Mar 2021 12:36:34 +0000 (12:36 +0000)
committerGitHub <noreply@github.com>
Wed, 24 Mar 2021 12:36:34 +0000 (12:36 +0000)
* Added httpx.BaseTransport and httpx.AsyncBaseTransport

* Test coverage and default transports to calling .close on __exit__

* BaseTransport documentation

* Use 'handle_request' for the transport API.

* Docs tweaks

* Docs tweaks

* Minor docstring tweak

* Transport API docs

* Drop 'Optional' on Transport API

* Docs tweaks

* Tweak CHANGELOG

* Drop erronous example.py

* Push httpcore exception wrapping out of client into transport (#1524)

* Push httpcore exception wrapping out of client into transport

* Include close/aclose extensions in docstring

* Comment about the request property on RequestError exceptions

* Extensions reason_phrase and http_version as bytes (#1526)

* Extensions reason_phrase and http_version as bytes

* Update BaseTransport docstring

* Neaten up our try...except structure for ensuring responses (#1525)

* Fix CHANGELOG typo

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
* Fix CHANGELOG typo

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
* stream: Iterator[bytes] -> stream: Iterable[bytes]

* Use proper bytestream interfaces when calling into httpcore

* Grungy typing workaround due to httpcore using Iterator instead of Iterable in bytestream types

* Update docs/advanced.md

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
* Consistent typing imports across tranports

* Update docs/advanced.md

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
20 files changed:
CHANGELOG.md
docs/advanced.md
httpx/__init__.py
httpx/_client.py
httpx/_decoders.py
httpx/_exceptions.py
httpx/_models.py
httpx/_transports/asgi.py
httpx/_transports/base.py [new file with mode: 0644]
httpx/_transports/default.py
httpx/_transports/mock.py
httpx/_transports/wsgi.py
tests/client/test_async_client.py
tests/client/test_client.py
tests/client/test_redirects.py
tests/conftest.py
tests/models/test_responses.py
tests/test_asgi.py
tests/test_decoders.py
tests/test_exceptions.py

index 992f4e4e386a7c436d02bd19364cc272e9f6d910..fa4d61112adee29a7d7fbeb33cfb1385d4e7fcef 100644 (file)
@@ -4,14 +4,31 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
-## 0.17.1
+## Master
+
+The 0.18.x release series formalises our low-level Transport API, introducing the
+base classes `httpx.BaseTransport` and `httpx.AsyncBaseTransport`.
+
+See the "Writing custom transports" documentation and the `httpx.BaseTransport.handle_request()`
+docstring for more complete details on implementing custom transports.
+
+Pull request #1522 includes a checklist of differences from the previous `httpcore` transport API,
+for developers implementing custom transports.
+
+### Changed
+
+* Transport instances now inherit from `httpx.BaseTransport` or `httpx.AsyncBaseTransport`,
+  and should implement either the `handle_request` method or `handle_async_request` method.
+* The `response.ext` property and `Response(ext=...)` argument are now named `extensions`.
+
+## 0.17.1 (March 15th, 2021)
 
 ### Fixed
 
 * Type annotation on `CertTypes` allows `keyfile` and `password` to be optional. (Pull #1503)
 * Fix httpcore pinned version. (Pull #1495)
 
-## 0.17.0
+## 0.17.0 (February 28th, 2021)
 
 ### Added
 
index 61bf4c1938a3a633e1498d9fc4c5e8d8dc6b6778..0b31b47855f30c31e19d191f701d50a70767523b 100644 (file)
@@ -1015,31 +1015,39 @@ This [public gist](https://gist.github.com/florimondmanca/d56764d78d748eb9f73165
 
 ### Writing custom transports
 
-A transport instance must implement the Transport API defined by
-[`httpcore`](https://www.encode.io/httpcore/api/). You
-should either subclass `httpcore.AsyncHTTPTransport` to implement a transport to
-use with `AsyncClient`, or subclass `httpcore.SyncHTTPTransport` to implement a
-transport to use with `Client`.
+A transport instance must implement the low-level Transport API, which deals
+with sending a single request, and returning a response. You should either
+subclass `httpx.BaseTransport` to implement a transport to use with `Client`,
+or subclass `httpx.AsyncBaseTransport` to implement a transport to
+use with `AsyncClient`.
+
+At the layer of the transport API we're using plain primitives.
+No `Request` or `Response` models, no fancy `URL` or `Header` handling.
+This strict point of cut-off provides a clear design separation between the
+HTTPX API, and the low-level network handling.
+
+See the `handle_request` and `handle_async_request` docstrings for more details
+on the specifics of the Transport API.
 
 A complete example of a custom transport implementation would be:
 
 ```python
 import json
-import httpcore
+import httpx
 
 
-class HelloWorldTransport(httpcore.SyncHTTPTransport):
+class HelloWorldTransport(httpx.BaseTransport):
     """
     A mock transport that always returns a JSON "Hello, world!" response.
     """
 
-    def request(self, method, url, headers=None, stream=None, ext=None):
+    def handle_request(self, method, url, headers, stream, extensions):
         message = {"text": "Hello, world!"}
         content = json.dumps(message).encode("utf-8")
-        stream = httpcore.PlainByteStream(content)
+        stream = [content]
         headers = [(b"content-type", b"application/json")]
-        ext = {"http_version": b"HTTP/1.1"}
-        return 200, headers, stream, ext
+        extensions = {}
+        return 200, headers, stream, extensions
 ```
 
 Which we can use in the same way:
@@ -1084,24 +1092,23 @@ which transport an outgoing request should be routed via, with [the same style
 used for specifying proxy routing](#routing).
 
 ```python
-import httpcore
 import httpx
 
-class HTTPSRedirectTransport(httpcore.SyncHTTPTransport):
+class HTTPSRedirectTransport(httpx.BaseTransport):
     """
     A transport that always redirects to HTTPS.
     """
 
-    def request(self, method, url, headers=None, stream=None, ext=None):
+    def handle_request(self, method, url, headers, stream, extensions):
         scheme, host, port, path = url
         if port is None:
             location = b"https://%s%s" % (host, path)
         else:
             location = b"https://%s:%d%s" % (host, port, path)
-        stream = httpcore.PlainByteStream(b"")
+        stream = [b""]
         headers = [(b"location", location)]
-        ext = {"http_version": b"HTTP/1.1"}
-        return 303, headers, stream, ext
+        extensions = {}
+        return 303, headers, stream, extensions
 
 
 # A client where any `http` requests are always redirected to `https`
index 96d9e0c2f8ea27271cba35f5429fde5ea65ab75e..a441669bf6a75fd9038609571b9225f92656cdea 100644 (file)
@@ -36,6 +36,7 @@ from ._exceptions import (
 from ._models import URL, Cookies, Headers, QueryParams, Request, Response
 from ._status_codes import StatusCode, codes
 from ._transports.asgi import ASGITransport
+from ._transports.base import AsyncBaseTransport, BaseTransport
 from ._transports.default import AsyncHTTPTransport, HTTPTransport
 from ._transports.mock import MockTransport
 from ._transports.wsgi import WSGITransport
@@ -45,9 +46,11 @@ __all__ = [
     "__title__",
     "__version__",
     "ASGITransport",
+    "AsyncBaseTransport",
     "AsyncClient",
     "AsyncHTTPTransport",
     "Auth",
+    "BaseTransport",
     "BasicAuth",
     "Client",
     "CloseError",
index da38a14346bcc079daccb171702b295fd1593166..691111ba134d689d77112141d3931b30f15e5795 100644 (file)
@@ -4,8 +4,6 @@ import typing
 import warnings
 from types import TracebackType
 
-import httpcore
-
 from .__version__ import __version__
 from ._auth import Auth, BasicAuth, FunctionAuth
 from ._config import (
@@ -20,15 +18,15 @@ from ._config import (
 )
 from ._decoders import SUPPORTED_DECODERS
 from ._exceptions import (
-    HTTPCORE_EXC_MAP,
     InvalidURL,
     RemoteProtocolError,
     TooManyRedirects,
-    map_exceptions,
+    request_context,
 )
 from ._models import URL, Cookies, Headers, QueryParams, Request, Response
 from ._status_codes import codes
 from ._transports.asgi import ASGITransport
+from ._transports.base import AsyncBaseTransport, BaseTransport
 from ._transports.default import AsyncHTTPTransport, HTTPTransport
 from ._transports.wsgi import WSGITransport
 from ._types import (
@@ -569,14 +567,14 @@ class Client(BaseClient):
         cert: CertTypes = None,
         http2: bool = False,
         proxies: ProxiesTypes = None,
-        mounts: typing.Mapping[str, httpcore.SyncHTTPTransport] = None,
+        mounts: typing.Mapping[str, BaseTransport] = None,
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         limits: Limits = DEFAULT_LIMITS,
         pool_limits: Limits = None,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
         event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
         base_url: URLTypes = "",
-        transport: httpcore.SyncHTTPTransport = None,
+        transport: BaseTransport = None,
         app: typing.Callable = None,
         trust_env: bool = True,
     ):
@@ -620,9 +618,7 @@ class Client(BaseClient):
             app=app,
             trust_env=trust_env,
         )
-        self._mounts: typing.Dict[
-            URLPattern, typing.Optional[httpcore.SyncHTTPTransport]
-        ] = {
+        self._mounts: typing.Dict[URLPattern, typing.Optional[BaseTransport]] = {
             URLPattern(key): None
             if proxy is None
             else self._init_proxy_transport(
@@ -648,10 +644,10 @@ class Client(BaseClient):
         cert: CertTypes = None,
         http2: bool = False,
         limits: Limits = DEFAULT_LIMITS,
-        transport: httpcore.SyncHTTPTransport = None,
+        transport: BaseTransport = None,
         app: typing.Callable = None,
         trust_env: bool = True,
-    ) -> httpcore.SyncHTTPTransport:
+    ) -> BaseTransport:
         if transport is not None:
             return transport
 
@@ -670,7 +666,7 @@ class Client(BaseClient):
         http2: bool = False,
         limits: Limits = DEFAULT_LIMITS,
         trust_env: bool = True,
-    ) -> httpcore.SyncHTTPTransport:
+    ) -> BaseTransport:
         return HTTPTransport(
             verify=verify,
             cert=cert,
@@ -680,7 +676,7 @@ class Client(BaseClient):
             proxy=proxy,
         )
 
-    def _transport_for_url(self, url: URL) -> httpcore.SyncHTTPTransport:
+    def _transport_for_url(self, url: URL) -> BaseTransport:
         """
         Returns the transport instance that should be used for a given URL.
         This will either be the standard connection pool, or a proxy.
@@ -775,21 +771,18 @@ class Client(BaseClient):
             allow_redirects=allow_redirects,
             history=[],
         )
-
-        if not stream:
-            try:
+        try:
+            if not stream:
                 response.read()
-            finally:
-                response.close()
 
-        try:
             for hook in self._event_hooks["response"]:
                 hook(response)
-        except Exception:
-            response.close()
-            raise
 
-        return response
+            return response
+
+        except Exception as exc:
+            response.close()
+            raise exc
 
     def _send_handling_auth(
         self,
@@ -813,18 +806,20 @@ class Client(BaseClient):
                 history=history,
             )
             try:
-                next_request = auth_flow.send(response)
-            except StopIteration:
-                return response
-            except BaseException as exc:
-                response.close()
-                raise exc from None
-            else:
+                try:
+                    next_request = auth_flow.send(response)
+                except StopIteration:
+                    return response
+
                 response.history = list(history)
                 response.read()
                 request = next_request
                 history.append(response)
 
+            except Exception as exc:
+                response.close()
+                raise exc
+
     def _send_handling_redirects(
         self,
         request: Request,
@@ -839,19 +834,24 @@ class Client(BaseClient):
                 )
 
             response = self._send_single_request(request, timeout)
-            response.history = list(history)
+            try:
+                response.history = list(history)
 
-            if not response.is_redirect:
-                return response
+                if not response.is_redirect:
+                    return response
 
-            if allow_redirects:
-                response.read()
-            request = self._build_redirect_request(request, response)
-            history = history + [response]
+                request = self._build_redirect_request(request, response)
+                history = history + [response]
 
-            if not allow_redirects:
-                response.next_request = request
-                return response
+                if allow_redirects:
+                    response.read()
+                else:
+                    response.next_request = request
+                    return response
+
+            except Exception as exc:
+                response.close()
+                raise exc
 
     def _send_single_request(self, request: Request, timeout: Timeout) -> Response:
         """
@@ -861,25 +861,25 @@ class Client(BaseClient):
         timer = Timer()
         timer.sync_start()
 
-        with map_exceptions(HTTPCORE_EXC_MAP, request=request):
-            (status_code, headers, stream, ext) = transport.request(
+        with request_context(request=request):
+            (status_code, headers, stream, extensions) = transport.handle_request(
                 request.method.encode(),
                 request.url.raw,
                 headers=request.headers.raw,
                 stream=request.stream,  # type: ignore
-                ext={"timeout": timeout.as_dict()},
+                extensions={"timeout": timeout.as_dict()},
             )
 
         def on_close(response: Response) -> None:
             response.elapsed = datetime.timedelta(seconds=timer.sync_elapsed())
-            if hasattr(stream, "close"):
-                stream.close()
+            if "close" in extensions:
+                extensions["close"]()
 
         response = Response(
             status_code,
             headers=headers,
-            stream=stream,  # type: ignore
-            ext=ext,
+            stream=stream,
+            extensions=extensions,
             request=request,
             on_close=on_close,
         )
@@ -1202,14 +1202,14 @@ class AsyncClient(BaseClient):
         cert: CertTypes = None,
         http2: bool = False,
         proxies: ProxiesTypes = None,
-        mounts: typing.Mapping[str, httpcore.AsyncHTTPTransport] = None,
+        mounts: typing.Mapping[str, AsyncBaseTransport] = None,
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         limits: Limits = DEFAULT_LIMITS,
         pool_limits: Limits = None,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
         event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
         base_url: URLTypes = "",
-        transport: httpcore.AsyncHTTPTransport = None,
+        transport: AsyncBaseTransport = None,
         app: typing.Callable = None,
         trust_env: bool = True,
     ):
@@ -1254,9 +1254,7 @@ class AsyncClient(BaseClient):
             trust_env=trust_env,
         )
 
-        self._mounts: typing.Dict[
-            URLPattern, typing.Optional[httpcore.AsyncHTTPTransport]
-        ] = {
+        self._mounts: typing.Dict[URLPattern, typing.Optional[AsyncBaseTransport]] = {
             URLPattern(key): None
             if proxy is None
             else self._init_proxy_transport(
@@ -1281,10 +1279,10 @@ class AsyncClient(BaseClient):
         cert: CertTypes = None,
         http2: bool = False,
         limits: Limits = DEFAULT_LIMITS,
-        transport: httpcore.AsyncHTTPTransport = None,
+        transport: AsyncBaseTransport = None,
         app: typing.Callable = None,
         trust_env: bool = True,
-    ) -> httpcore.AsyncHTTPTransport:
+    ) -> AsyncBaseTransport:
         if transport is not None:
             return transport
 
@@ -1303,7 +1301,7 @@ class AsyncClient(BaseClient):
         http2: bool = False,
         limits: Limits = DEFAULT_LIMITS,
         trust_env: bool = True,
-    ) -> httpcore.AsyncHTTPTransport:
+    ) -> AsyncBaseTransport:
         return AsyncHTTPTransport(
             verify=verify,
             cert=cert,
@@ -1313,7 +1311,7 @@ class AsyncClient(BaseClient):
             proxy=proxy,
         )
 
-    def _transport_for_url(self, url: URL) -> httpcore.AsyncHTTPTransport:
+    def _transport_for_url(self, url: URL) -> AsyncBaseTransport:
         """
         Returns the transport instance that should be used for a given URL.
         This will either be the standard connection pool, or a proxy.
@@ -1409,21 +1407,18 @@ class AsyncClient(BaseClient):
             allow_redirects=allow_redirects,
             history=[],
         )
-
-        if not stream:
-            try:
+        try:
+            if not stream:
                 await response.aread()
-            finally:
-                await response.aclose()
 
-        try:
             for hook in self._event_hooks["response"]:
                 await hook(response)
-        except Exception:
-            await response.aclose()
-            raise
 
-        return response
+            return response
+
+        except Exception as exc:
+            await response.aclose()
+            raise exc
 
     async def _send_handling_auth(
         self,
@@ -1447,18 +1442,20 @@ class AsyncClient(BaseClient):
                 history=history,
             )
             try:
-                next_request = await auth_flow.asend(response)
-            except StopAsyncIteration:
-                return response
-            except BaseException as exc:
-                await response.aclose()
-                raise exc from None
-            else:
+                try:
+                    next_request = await auth_flow.asend(response)
+                except StopAsyncIteration:
+                    return response
+
                 response.history = list(history)
                 await response.aread()
                 request = next_request
                 history.append(response)
 
+            except Exception as exc:
+                await response.aclose()
+                raise exc
+
     async def _send_handling_redirects(
         self,
         request: Request,
@@ -1473,19 +1470,24 @@ class AsyncClient(BaseClient):
                 )
 
             response = await self._send_single_request(request, timeout)
-            response.history = list(history)
+            try:
+                response.history = list(history)
 
-            if not response.is_redirect:
-                return response
+                if not response.is_redirect:
+                    return response
 
-            if allow_redirects:
-                await response.aread()
-            request = self._build_redirect_request(request, response)
-            history = history + [response]
+                request = self._build_redirect_request(request, response)
+                history = history + [response]
 
-            if not allow_redirects:
-                response.next_request = request
-                return response
+                if allow_redirects:
+                    await response.aread()
+                else:
+                    response.next_request = request
+                    return response
+
+            except Exception as exc:
+                await response.aclose()
+                raise exc
 
     async def _send_single_request(
         self, request: Request, timeout: Timeout
@@ -1497,26 +1499,30 @@ class AsyncClient(BaseClient):
         timer = Timer()
         await timer.async_start()
 
-        with map_exceptions(HTTPCORE_EXC_MAP, request=request):
-            (status_code, headers, stream, ext) = await transport.arequest(
+        with request_context(request=request):
+            (
+                status_code,
+                headers,
+                stream,
+                extensions,
+            ) = await transport.handle_async_request(
                 request.method.encode(),
                 request.url.raw,
                 headers=request.headers.raw,
                 stream=request.stream,  # type: ignore
-                ext={"timeout": timeout.as_dict()},
+                extensions={"timeout": timeout.as_dict()},
             )
 
         async def on_close(response: Response) -> None:
             response.elapsed = datetime.timedelta(seconds=await timer.async_elapsed())
-            if hasattr(stream, "aclose"):
-                with map_exceptions(HTTPCORE_EXC_MAP, request=request):
-                    await stream.aclose()
+            if "aclose" in extensions:
+                await extensions["aclose"]()
 
         response = Response(
             status_code,
             headers=headers,
-            stream=stream,  # type: ignore
-            ext=ext,
+            stream=stream,
+            extensions=extensions,
             request=request,
             on_close=on_close,
         )
index 8ef0157e6f5d14afdb0fc1a1c472bb405d41146c..c0d51a4cdc8ed099423e3d763cde56d15d0269c9 100644 (file)
@@ -8,6 +8,8 @@ import io
 import typing
 import zlib
 
+from ._exceptions import DecodingError
+
 try:
     import brotli
 except ImportError:  # pragma: nocover
@@ -54,13 +56,13 @@ class DeflateDecoder(ContentDecoder):
             if was_first_attempt:
                 self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
                 return self.decode(data)
-            raise ValueError(str(exc))
+            raise DecodingError(str(exc)) from exc
 
     def flush(self) -> bytes:
         try:
             return self.decompressor.flush()
         except zlib.error as exc:  # pragma: nocover
-            raise ValueError(str(exc))
+            raise DecodingError(str(exc)) from exc
 
 
 class GZipDecoder(ContentDecoder):
@@ -77,13 +79,13 @@ class GZipDecoder(ContentDecoder):
         try:
             return self.decompressor.decompress(data)
         except zlib.error as exc:
-            raise ValueError(str(exc))
+            raise DecodingError(str(exc)) from exc
 
     def flush(self) -> bytes:
         try:
             return self.decompressor.flush()
         except zlib.error as exc:  # pragma: nocover
-            raise ValueError(str(exc))
+            raise DecodingError(str(exc)) from exc
 
 
 class BrotliDecoder(ContentDecoder):
@@ -118,7 +120,7 @@ class BrotliDecoder(ContentDecoder):
         try:
             return self._decompress(data)
         except brotli.error as exc:
-            raise ValueError(str(exc))
+            raise DecodingError(str(exc)) from exc
 
     def flush(self) -> bytes:
         if not self.seen_data:
@@ -128,7 +130,7 @@ class BrotliDecoder(ContentDecoder):
                 self.decompressor.finish()
             return b""
         except brotli.error as exc:  # pragma: nocover
-            raise ValueError(str(exc))
+            raise DecodingError(str(exc)) from exc
 
 
 class MultiDecoder(ContentDecoder):
index bade9f9b8167e72df54b7fca1101694b824e882b..092dbcf04eb2ebeeb01481980d0d368df64c8824 100644 (file)
@@ -34,8 +34,6 @@ Our exception hierarchy:
 import contextlib
 import typing
 
-import httpcore
-
 if typing.TYPE_CHECKING:
     from ._models import Request, Response  # pragma: nocover
 
@@ -58,9 +56,8 @@ class HTTPError(Exception):
     ```
     """
 
-    def __init__(self, message: str, *, request: "Request") -> None:
+    def __init__(self, message: str) -> None:
         super().__init__(message)
-        self.request = request
 
 
 class RequestError(HTTPError):
@@ -68,15 +65,30 @@ class RequestError(HTTPError):
     Base class for all exceptions that may occur when issuing a `.request()`.
     """
 
-    def __init__(self, message: str, *, request: "Request") -> None:
-        super().__init__(message, request=request)
+    def __init__(self, message: str, *, request: "Request" = None) -> None:
+        super().__init__(message)
+        # At the point an exception is raised we won't typically have a request
+        # instance to associate it with.
+        #
+        # The 'request_context' context manager is used within the Client and
+        # Response methods in order to ensure that any raised exceptions
+        # have a `.request` property set on them.
+        self._request = request
+
+    @property
+    def request(self) -> "Request":
+        if self._request is None:
+            raise RuntimeError("The .request property has not been set.")
+        return self._request
+
+    @request.setter
+    def request(self, request: "Request") -> None:
+        self._request = request
 
 
 class TransportError(RequestError):
     """
     Base class for all exceptions that occur at the level of the Transport API.
-
-    All of these exceptions also have an equivelent mapping in `httpcore`.
     """
 
 
@@ -219,7 +231,8 @@ class HTTPStatusError(HTTPError):
     def __init__(
         self, message: str, *, request: "Request", response: "Response"
     ) -> None:
-        super().__init__(message, request=request)
+        super().__init__(message)
+        self.request = request
         self.response = response
 
 
@@ -318,45 +331,14 @@ class ResponseClosed(StreamError):
 
 
 @contextlib.contextmanager
-def map_exceptions(
-    mapping: typing.Mapping[typing.Type[Exception], typing.Type[Exception]],
-    **kwargs: typing.Any,
-) -> typing.Iterator[None]:
+def request_context(request: "Request" = None) -> typing.Iterator[None]:
+    """
+    A context manager that can be used to attach the given request context
+    to any `RequestError` exceptions that are raised within the block.
+    """
     try:
         yield
-    except Exception as exc:
-        mapped_exc = None
-
-        for from_exc, to_exc in mapping.items():
-            if not isinstance(exc, from_exc):
-                continue
-            # We want to map to the most specific exception we can find.
-            # Eg if `exc` is an `httpcore.ReadTimeout`, we want to map to
-            # `httpx.ReadTimeout`, not just `httpx.TimeoutException`.
-            if mapped_exc is None or issubclass(to_exc, mapped_exc):
-                mapped_exc = to_exc
-
-        if mapped_exc is None:
-            raise
-
-        message = str(exc)
-        raise mapped_exc(message, **kwargs) from exc  # type: ignore
-
-
-HTTPCORE_EXC_MAP = {
-    httpcore.TimeoutException: TimeoutException,
-    httpcore.ConnectTimeout: ConnectTimeout,
-    httpcore.ReadTimeout: ReadTimeout,
-    httpcore.WriteTimeout: WriteTimeout,
-    httpcore.PoolTimeout: PoolTimeout,
-    httpcore.NetworkError: NetworkError,
-    httpcore.ConnectError: ConnectError,
-    httpcore.ReadError: ReadError,
-    httpcore.WriteError: WriteError,
-    httpcore.CloseError: CloseError,
-    httpcore.ProxyError: ProxyError,
-    httpcore.UnsupportedProtocol: UnsupportedProtocol,
-    httpcore.ProtocolError: ProtocolError,
-    httpcore.LocalProtocolError: LocalProtocolError,
-    httpcore.RemoteProtocolError: RemoteProtocolError,
-}
+    except RequestError as exc:
+        if request is not None:
+            exc.request = request
+        raise exc
index 83deb9a243530f0a78038805b8143a249cd45bbb..34fb2d388c448eb760231e4109bb44b6e1d8a761 100644 (file)
@@ -1,5 +1,4 @@
 import cgi
-import contextlib
 import datetime
 import email.message
 import json as jsonlib
@@ -24,16 +23,14 @@ from ._decoders import (
     TextDecoder,
 )
 from ._exceptions import (
-    HTTPCORE_EXC_MAP,
     CookieConflict,
-    DecodingError,
     HTTPStatusError,
     InvalidURL,
     RequestNotRead,
     ResponseClosed,
     ResponseNotRead,
     StreamConsumed,
-    map_exceptions,
+    request_context,
 )
 from ._status_codes import codes
 from ._types import (
@@ -909,7 +906,7 @@ class Response:
         json: typing.Any = None,
         stream: ByteStream = None,
         request: Request = None,
-        ext: dict = None,
+        extensions: dict = None,
         history: typing.List["Response"] = None,
         on_close: typing.Callable = None,
     ):
@@ -924,7 +921,7 @@ class Response:
 
         self.call_next: typing.Optional[typing.Callable] = None
 
-        self.ext = {} if ext is None else ext
+        self.extensions = {} if extensions is None else extensions
         self.history = [] if history is None else list(history)
         self._on_close = on_close
 
@@ -995,11 +992,17 @@ class Response:
 
     @property
     def http_version(self) -> str:
-        return self.ext.get("http_version", "HTTP/1.1")
+        try:
+            return self.extensions["http_version"].decode("ascii", errors="ignore")
+        except KeyError:
+            return "HTTP/1.1"
 
     @property
     def reason_phrase(self) -> str:
-        return self.ext.get("reason", codes.get_reason_phrase(self.status_code))
+        try:
+            return self.extensions["reason_phrase"].decode("ascii", errors="ignore")
+        except KeyError:
+            return codes.get_reason_phrase(self.status_code)
 
     @property
     def url(self) -> typing.Optional[URL]:
@@ -1152,17 +1155,6 @@ class Response:
     def __repr__(self) -> str:
         return f"<Response [{self.status_code} {self.reason_phrase}]>"
 
-    @contextlib.contextmanager
-    def _wrap_decoder_errors(self) -> typing.Iterator[None]:
-        # If the response has an associated request instance, we want decoding
-        # errors to be raised as proper `httpx.DecodingError` exceptions.
-        try:
-            yield
-        except ValueError as exc:
-            if self._request is None:
-                raise exc
-            raise DecodingError(message=str(exc), request=self.request) from exc
-
     def read(self) -> bytes:
         """
         Read and return the response content.
@@ -1183,7 +1175,7 @@ class Response:
         else:
             decoder = self._get_content_decoder()
             chunker = ByteChunker(chunk_size=chunk_size)
-            with self._wrap_decoder_errors():
+            with request_context(request=self._request):
                 for raw_bytes in self.iter_raw():
                     decoded = decoder.decode(raw_bytes)
                     for chunk in chunker.decode(decoded):
@@ -1202,7 +1194,7 @@ class Response:
         """
         decoder = TextDecoder(encoding=self.encoding)
         chunker = TextChunker(chunk_size=chunk_size)
-        with self._wrap_decoder_errors():
+        with request_context(request=self._request):
             for byte_content in self.iter_bytes():
                 text_content = decoder.decode(byte_content)
                 for chunk in chunker.decode(text_content):
@@ -1215,7 +1207,7 @@ class Response:
 
     def iter_lines(self) -> typing.Iterator[str]:
         decoder = LineDecoder()
-        with self._wrap_decoder_errors():
+        with request_context(request=self._request):
             for text in self.iter_text():
                 for line in decoder.decode(text):
                     yield line
@@ -1237,7 +1229,7 @@ class Response:
         self._num_bytes_downloaded = 0
         chunker = ByteChunker(chunk_size=chunk_size)
 
-        with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
+        with request_context(request=self._request):
             for raw_stream_bytes in self.stream:
                 self._num_bytes_downloaded += len(raw_stream_bytes)
                 for chunk in chunker.decode(raw_stream_bytes):
@@ -1256,7 +1248,8 @@ class Response:
         if not self.is_closed:
             self.is_closed = True
             if self._on_close is not None:
-                self._on_close(self)
+                with request_context(request=self._request):
+                    self._on_close(self)
 
     async def aread(self) -> bytes:
         """
@@ -1278,7 +1271,7 @@ class Response:
         else:
             decoder = self._get_content_decoder()
             chunker = ByteChunker(chunk_size=chunk_size)
-            with self._wrap_decoder_errors():
+            with request_context(request=self._request):
                 async for raw_bytes in self.aiter_raw():
                     decoded = decoder.decode(raw_bytes)
                     for chunk in chunker.decode(decoded):
@@ -1297,7 +1290,7 @@ class Response:
         """
         decoder = TextDecoder(encoding=self.encoding)
         chunker = TextChunker(chunk_size=chunk_size)
-        with self._wrap_decoder_errors():
+        with request_context(request=self._request):
             async for byte_content in self.aiter_bytes():
                 text_content = decoder.decode(byte_content)
                 for chunk in chunker.decode(text_content):
@@ -1310,7 +1303,7 @@ class Response:
 
     async def aiter_lines(self) -> typing.AsyncIterator[str]:
         decoder = LineDecoder()
-        with self._wrap_decoder_errors():
+        with request_context(request=self._request):
             async for text in self.aiter_text():
                 for line in decoder.decode(text):
                     yield line
@@ -1332,7 +1325,7 @@ class Response:
         self._num_bytes_downloaded = 0
         chunker = ByteChunker(chunk_size=chunk_size)
 
-        with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
+        with request_context(request=self._request):
             async for raw_stream_bytes in self.stream:
                 self._num_bytes_downloaded += len(raw_stream_bytes)
                 for chunk in chunker.decode(raw_stream_bytes):
@@ -1351,7 +1344,8 @@ class Response:
         if not self.is_closed:
             self.is_closed = True
             if self._on_close is not None:
-                await self._on_close(self)
+                with request_context(request=self._request):
+                    await self._on_close(self)
 
 
 class Cookies(MutableMapping):
index 758d8375b2165189662df29a37844a348facda89..ef0a3ef29ab43843d9847f9b35accd6df8ef6219 100644 (file)
@@ -1,15 +1,16 @@
-from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
+import typing
 from urllib.parse import unquote
 
-import httpcore
 import sniffio
 
-if TYPE_CHECKING:  # pragma: no cover
+from .base import AsyncBaseTransport
+
+if typing.TYPE_CHECKING:  # pragma: no cover
     import asyncio
 
     import trio
 
-    Event = Union[asyncio.Event, trio.Event]
+    Event = typing.Union[asyncio.Event, trio.Event]
 
 
 def create_event() -> "Event":
@@ -23,7 +24,7 @@ def create_event() -> "Event":
         return asyncio.Event()
 
 
-class ASGITransport(httpcore.AsyncHTTPTransport):
+class ASGITransport(AsyncBaseTransport):
     """
     A custom AsyncTransport that handles sending requests directly to an ASGI app.
     The simplest way to use this functionality is to use the `app` argument.
@@ -58,27 +59,26 @@ class ASGITransport(httpcore.AsyncHTTPTransport):
 
     def __init__(
         self,
-        app: Callable,
+        app: typing.Callable,
         raise_app_exceptions: bool = True,
         root_path: str = "",
-        client: Tuple[str, int] = ("127.0.0.1", 123),
+        client: typing.Tuple[str, int] = ("127.0.0.1", 123),
     ) -> None:
         self.app = app
         self.raise_app_exceptions = raise_app_exceptions
         self.root_path = root_path
         self.client = client
 
-    async def arequest(
+    async def handle_async_request(
         self,
         method: bytes,
-        url: Tuple[bytes, bytes, Optional[int], bytes],
-        headers: List[Tuple[bytes, bytes]] = None,
-        stream: httpcore.AsyncByteStream = None,
-        ext: dict = None,
-    ) -> Tuple[int, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream, dict]:
-        headers = [] if headers is None else headers
-        stream = httpcore.PlainByteStream(content=b"") if stream is None else stream
-
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
+        headers: typing.List[typing.Tuple[bytes, bytes]],
+        stream: typing.AsyncIterable[bytes],
+        extensions: dict,
+    ) -> typing.Tuple[
+        int, typing.List[typing.Tuple[bytes, bytes]], typing.AsyncIterable[bytes], dict
+    ]:
         # ASGI scope.
         scheme, host, port, full_path = url
         path, _, query = full_path.partition(b"?")
@@ -155,7 +155,9 @@ class ASGITransport(httpcore.AsyncHTTPTransport):
         assert status_code is not None
         assert response_headers is not None
 
-        stream = httpcore.PlainByteStream(content=b"".join(body_parts))
-        ext = {}
+        async def response_stream() -> typing.AsyncIterator[bytes]:
+            yield b"".join(body_parts)
+
+        extensions = {}
 
-        return (status_code, response_headers, stream, ext)
+        return (status_code, response_headers, response_stream(), extensions)
diff --git a/httpx/_transports/base.py b/httpx/_transports/base.py
new file mode 100644 (file)
index 0000000..e26938f
--- /dev/null
@@ -0,0 +1,129 @@
+import typing
+from types import TracebackType
+
+T = typing.TypeVar("T", bound="BaseTransport")
+A = typing.TypeVar("A", bound="AsyncBaseTransport")
+
+
+class BaseTransport:
+    def __enter__(self: T) -> T:
+        return self
+
+    def __exit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: TracebackType = None,
+    ) -> None:
+        self.close()
+
+    def handle_request(
+        self,
+        method: bytes,
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
+        headers: typing.List[typing.Tuple[bytes, bytes]],
+        stream: typing.Iterable[bytes],
+        extensions: dict,
+    ) -> typing.Tuple[
+        int, typing.List[typing.Tuple[bytes, bytes]], typing.Iterable[bytes], dict
+    ]:
+        """
+        Send a single HTTP request and return a response.
+
+        At this layer of API we're simply using plain primitives. No `Request` or
+        `Response` models, no fancy `URL` or `Header` handling. This strict point
+        of cut-off provides a clear design seperation between the HTTPX API,
+        and the low-level network handling.
+
+        Developers shouldn't typically ever need to call into this API directly,
+        since the Client class provides all the higher level user-facing API
+        niceties.
+
+        Example usage:
+
+            with httpx.HTTPTransport() as transport:
+                status_code, headers, stream, extensions = transport.handle_request(
+                    method=b'GET',
+                    url=(b'https', b'www.example.com', 443, b'/'),
+                    headers=[(b'Host', b'www.example.com')],
+                    stream=[],
+                    extensions={}
+                )
+                try:
+                    body = b''.join([part for part in stream])
+                finally:
+                    if 'close' in extensions:
+                        extensions['close']()
+                print(status_code, headers, body)
+
+        Arguments:
+
+        method: The request method as bytes. Eg. b'GET'.
+        url: The components of the request URL, as a tuple of `(scheme, host, port, target)`.
+             The target will usually be the URL path, but also allows for alternative
+             formulations, such as proxy requests which include the complete URL in
+             the target portion of the HTTP request, or for "OPTIONS *" requests, which
+             cannot be expressed in a URL string.
+        headers: The request headers as a list of byte pairs.
+        stream: The request body as a bytes iterator.
+        extensions: An open ended dictionary, including optional extensions to the
+                    core request/response API. Keys may include:
+            timeout: A dictionary of str:Optional[float] timeout values.
+                     May include values for 'connect', 'read', 'write', or 'pool'.
+
+        Returns a tuple of:
+
+        status_code: The response status code as an integer. Should be in the range 1xx-5xx.
+        headers: The response headers as a list of byte pairs.
+        stream: The response body as a bytes iterator.
+        extensions: An open ended dictionary, including optional extensions to the
+                    core request/response API. Keys are plain strings, and may include:
+            reason_phrase: The reason-phrase of the HTTP response, as bytes. Eg b'OK'.
+                    HTTP/2 onwards does not include a reason phrase on the wire.
+                    When no key is included, a default based on the status code may
+                    be used. An empty-string reason phrase should not be substituted
+                    for a default, as it indicates the server left the portion blank
+                    eg. the leading response bytes were b"HTTP/1.1 200 <CRLF>".
+            http_version: The HTTP version, as bytes. Eg. b"HTTP/1.1".
+                    When no http_version key is included, HTTP/1.1 may be assumed.
+            close:  A callback which should be invoked to release any network
+                    resources.
+            aclose: An async callback which should be invoked to release any
+                    network resources.
+        """
+        raise NotImplementedError(
+            "The 'handle_request' method must be implemented."
+        )  # pragma: nocover
+
+    def close(self) -> None:
+        pass
+
+
+class AsyncBaseTransport:
+    async def __aenter__(self: A) -> A:
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: TracebackType = None,
+    ) -> None:
+        await self.aclose()
+
+    async def handle_async_request(
+        self,
+        method: bytes,
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
+        headers: typing.List[typing.Tuple[bytes, bytes]],
+        stream: typing.AsyncIterable[bytes],
+        extensions: dict,
+    ) -> typing.Tuple[
+        int, typing.List[typing.Tuple[bytes, bytes]], typing.AsyncIterable[bytes], dict
+    ]:
+        raise NotImplementedError(
+            "The 'handle_async_request' method must be implemented."
+        )  # pragma: nocover
+
+    async def aclose(self) -> None:
+        pass
index 84aeb26be889415539dac86dcf68c3d53837f795..67f62322afcc2a21f04b883447482befef5ac4b5 100644 (file)
@@ -24,21 +24,93 @@ client = httpx.Client(transport=transport)
 transport = httpx.HTTPTransport(uds="socket.uds")
 client = httpx.Client(transport=transport)
 """
+import contextlib
 import typing
 from types import TracebackType
 
 import httpcore
 
 from .._config import DEFAULT_LIMITS, Limits, Proxy, create_ssl_context
+from .._exceptions import (
+    CloseError,
+    ConnectError,
+    ConnectTimeout,
+    LocalProtocolError,
+    NetworkError,
+    PoolTimeout,
+    ProtocolError,
+    ProxyError,
+    ReadError,
+    ReadTimeout,
+    RemoteProtocolError,
+    TimeoutException,
+    UnsupportedProtocol,
+    WriteError,
+    WriteTimeout,
+)
 from .._types import CertTypes, VerifyTypes
+from .base import AsyncBaseTransport, BaseTransport
 
 T = typing.TypeVar("T", bound="HTTPTransport")
 A = typing.TypeVar("A", bound="AsyncHTTPTransport")
-Headers = typing.List[typing.Tuple[bytes, bytes]]
-URL = typing.Tuple[bytes, bytes, typing.Optional[int], bytes]
 
 
-class HTTPTransport(httpcore.SyncHTTPTransport):
+@contextlib.contextmanager
+def map_httpcore_exceptions() -> typing.Iterator[None]:
+    try:
+        yield
+    except Exception as exc:
+        mapped_exc = None
+
+        for from_exc, to_exc in HTTPCORE_EXC_MAP.items():
+            if not isinstance(exc, from_exc):
+                continue
+            # We want to map to the most specific exception we can find.
+            # Eg if `exc` is an `httpcore.ReadTimeout`, we want to map to
+            # `httpx.ReadTimeout`, not just `httpx.TimeoutException`.
+            if mapped_exc is None or issubclass(to_exc, mapped_exc):
+                mapped_exc = to_exc
+
+        if mapped_exc is None:  # pragma: nocover
+            raise
+
+        message = str(exc)
+        raise mapped_exc(message) from exc
+
+
+def ensure_http_version_reason_phrase_as_bytes(extensions: dict) -> None:
+    # From HTTPX 0.18 onwards we're treating the "reason_phrase" and "http_version"
+    # extensions as bytes, in order to be more precise. Also we're using the
+    # "reason_phrase" key in preference to "reason", in order to match properly
+    # with the HTTP spec naming.
+    # HTTPCore 0.12 does not yet use these same conventions for the extensions,
+    # so we bridge between the two styles for now.
+    if "reason" in extensions:
+        extensions["reason_phrase"] = extensions.pop("reason").encode("ascii")
+    if "http_version" in extensions:
+        extensions["http_version"] = extensions["http_version"].encode("ascii")
+
+
+HTTPCORE_EXC_MAP = {
+    httpcore.TimeoutException: TimeoutException,
+    httpcore.ConnectTimeout: ConnectTimeout,
+    httpcore.ReadTimeout: ReadTimeout,
+    httpcore.WriteTimeout: WriteTimeout,
+    httpcore.PoolTimeout: PoolTimeout,
+    httpcore.NetworkError: NetworkError,
+    httpcore.ConnectError: ConnectError,
+    httpcore.ReadError: ReadError,
+    httpcore.WriteError: WriteError,
+    httpcore.CloseError: CloseError,
+    httpcore.ProxyError: ProxyError,
+    httpcore.UnsupportedProtocol: UnsupportedProtocol,
+    httpcore.ProtocolError: ProtocolError,
+    httpcore.LocalProtocolError: LocalProtocolError,
+    httpcore.RemoteProtocolError: RemoteProtocolError,
+}
+
+
+class HTTPTransport(BaseTransport):
     def __init__(
         self,
         verify: VerifyTypes = True,
@@ -91,21 +163,44 @@ class HTTPTransport(httpcore.SyncHTTPTransport):
     ) -> None:
         self._pool.__exit__(exc_type, exc_value, traceback)
 
-    def request(
+    def handle_request(
         self,
         method: bytes,
-        url: URL,
-        headers: Headers = None,
-        stream: httpcore.SyncByteStream = None,
-        ext: dict = None,
-    ) -> typing.Tuple[int, Headers, httpcore.SyncByteStream, dict]:
-        return self._pool.request(method, url, headers=headers, stream=stream, ext=ext)
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
+        headers: typing.List[typing.Tuple[bytes, bytes]],
+        stream: typing.Iterable[bytes],
+        extensions: dict,
+    ) -> typing.Tuple[
+        int, typing.List[typing.Tuple[bytes, bytes]], typing.Iterable[bytes], dict
+    ]:
+        with map_httpcore_exceptions():
+            status_code, headers, byte_stream, extensions = self._pool.request(
+                method=method,
+                url=url,
+                headers=headers,
+                stream=httpcore.IteratorByteStream(iter(stream)),
+                ext=extensions,
+            )
+
+        def response_stream() -> typing.Iterator[bytes]:
+            with map_httpcore_exceptions():
+                for part in byte_stream:
+                    yield part
+
+        def close() -> None:
+            with map_httpcore_exceptions():
+                byte_stream.close()
+
+        ensure_http_version_reason_phrase_as_bytes(extensions)
+        extensions["close"] = close
+
+        return status_code, headers, response_stream(), extensions
 
     def close(self) -> None:
         self._pool.close()
 
 
-class AsyncHTTPTransport(httpcore.AsyncHTTPTransport):
+class AsyncHTTPTransport(AsyncBaseTransport):
     def __init__(
         self,
         verify: VerifyTypes = True,
@@ -158,17 +253,38 @@ class AsyncHTTPTransport(httpcore.AsyncHTTPTransport):
     ) -> None:
         await self._pool.__aexit__(exc_type, exc_value, traceback)
 
-    async def arequest(
+    async def handle_async_request(
         self,
         method: bytes,
-        url: URL,
-        headers: Headers = None,
-        stream: httpcore.AsyncByteStream = None,
-        ext: dict = None,
-    ) -> typing.Tuple[int, Headers, httpcore.AsyncByteStream, dict]:
-        return await self._pool.arequest(
-            method, url, headers=headers, stream=stream, ext=ext
-        )
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
+        headers: typing.List[typing.Tuple[bytes, bytes]],
+        stream: typing.AsyncIterable[bytes],
+        extensions: dict,
+    ) -> typing.Tuple[
+        int, typing.List[typing.Tuple[bytes, bytes]], typing.AsyncIterable[bytes], dict
+    ]:
+        with map_httpcore_exceptions():
+            status_code, headers, byte_stream, extenstions = await self._pool.arequest(
+                method=method,
+                url=url,
+                headers=headers,
+                stream=httpcore.AsyncIteratorByteStream(stream.__aiter__()),
+                ext=extensions,
+            )
+
+        async def response_stream() -> typing.AsyncIterator[bytes]:
+            with map_httpcore_exceptions():
+                async for part in byte_stream:
+                    yield part
+
+        async def aclose() -> None:
+            with map_httpcore_exceptions():
+                await byte_stream.aclose()
+
+        ensure_http_version_reason_phrase_as_bytes(extensions)
+        extensions["aclose"] = aclose
+
+        return status_code, headers, response_stream(), extensions
 
     async def aclose(self) -> None:
         await self._pool.aclose()
index a55a88b7a2705191a67c0e21a67fcc773a0c98a5..b6ca353a315800214b1a46b34edd1f4f9a0849a6 100644 (file)
@@ -1,23 +1,24 @@
 import asyncio
-from typing import Callable, List, Optional, Tuple
-
-import httpcore
+import typing
 
 from .._models import Request
+from .base import AsyncBaseTransport, BaseTransport
 
 
-class MockTransport(httpcore.SyncHTTPTransport, httpcore.AsyncHTTPTransport):
-    def __init__(self, handler: Callable) -> None:
+class MockTransport(AsyncBaseTransport, BaseTransport):
+    def __init__(self, handler: typing.Callable) -> None:
         self.handler = handler
 
-    def request(
+    def handle_request(
         self,
         method: bytes,
-        url: Tuple[bytes, bytes, Optional[int], bytes],
-        headers: List[Tuple[bytes, bytes]] = None,
-        stream: httpcore.SyncByteStream = None,
-        ext: dict = None,
-    ) -> Tuple[int, List[Tuple[bytes, bytes]], httpcore.SyncByteStream, dict]:
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
+        headers: typing.List[typing.Tuple[bytes, bytes]],
+        stream: typing.Iterable[bytes],
+        extensions: dict,
+    ) -> typing.Tuple[
+        int, typing.List[typing.Tuple[bytes, bytes]], typing.Iterable[bytes], dict
+    ]:
         request = Request(
             method=method,
             url=url,
@@ -30,17 +31,19 @@ class MockTransport(httpcore.SyncHTTPTransport, httpcore.AsyncHTTPTransport):
             response.status_code,
             response.headers.raw,
             response.stream,
-            response.ext,
+            response.extensions,
         )
 
-    async def arequest(
+    async def handle_async_request(
         self,
         method: bytes,
-        url: Tuple[bytes, bytes, Optional[int], bytes],
-        headers: List[Tuple[bytes, bytes]] = None,
-        stream: httpcore.AsyncByteStream = None,
-        ext: dict = None,
-    ) -> Tuple[int, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream, dict]:
+        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
+        headers: typing.List[typing.Tuple[bytes, bytes]],
+        stream: typing.AsyncIterable[bytes],
+        extensions: dict,
+    ) -> typing.Tuple[
+        int, typing.List[typing.Tuple[bytes, bytes]], typing.AsyncIterable[bytes], dict
+    ]:
         request = Request(
             method=method,
             url=url,
@@ -63,5 +66,5 @@ class MockTransport(httpcore.SyncHTTPTransport, httpcore.AsyncHTTPTransport):
             response.status_code,
             response.headers.raw,
             response.stream,
-            response.ext,
+            response.extensions,
         )
index 67b44bde42f3f717637cc1e674ad15afb3fb99f2..3b7651fba71a000f53c50984fd4b7ea1be1e6632 100644 (file)
@@ -3,7 +3,7 @@ import itertools
 import typing
 from urllib.parse import unquote
 
-import httpcore
+from .base import BaseTransport
 
 
 def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable:
@@ -14,7 +14,7 @@ def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable:
     return []
 
 
-class WSGITransport(httpcore.SyncHTTPTransport):
+class WSGITransport(BaseTransport):
     """
     A custom transport that handles sending requests directly to an WSGI app.
     The simplest way to use this functionality is to use the `app` argument.
@@ -59,18 +59,17 @@ class WSGITransport(httpcore.SyncHTTPTransport):
         self.script_name = script_name
         self.remote_addr = remote_addr
 
-    def request(
+    def handle_request(
         self,
         method: bytes,
         url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]] = None,
-        stream: httpcore.SyncByteStream = None,
-        ext: dict = None,
+        headers: typing.List[typing.Tuple[bytes, bytes]],
+        stream: typing.Iterable[bytes],
+        extensions: dict,
     ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], httpcore.SyncByteStream, dict
+        int, typing.List[typing.Tuple[bytes, bytes]], typing.Iterable[bytes], dict
     ]:
-        headers = [] if headers is None else headers
-        stream = httpcore.PlainByteStream(content=b"") if stream is None else stream
+        wsgi_input = io.BytesIO(b"".join(stream))
 
         scheme, host, port, full_path = url
         path, _, query = full_path.partition(b"?")
@@ -80,7 +79,7 @@ class WSGITransport(httpcore.SyncHTTPTransport):
         environ = {
             "wsgi.version": (1, 0),
             "wsgi.url_scheme": scheme.decode("ascii"),
-            "wsgi.input": io.BytesIO(b"".join(stream)),
+            "wsgi.input": wsgi_input,
             "wsgi.errors": io.BytesIO(),
             "wsgi.multithread": True,
             "wsgi.multiprocess": False,
@@ -126,7 +125,6 @@ class WSGITransport(httpcore.SyncHTTPTransport):
             (key.encode("ascii"), value.encode("ascii"))
             for key, value in seen_response_headers
         ]
-        stream = httpcore.IteratorByteStream(iterator=result)
-        ext = {}
+        extensions = {}
 
-        return (status_code, headers, stream, ext)
+        return (status_code, headers, result, extensions)
index 1d3f4ccafa6faa29f65ca9eece76127f47d28896..99493c43ab3cf74c5ecbec08a911112e4356d275 100644 (file)
@@ -1,7 +1,6 @@
 import typing
 from datetime import timedelta
 
-import httpcore
 import pytest
 
 import httpx
@@ -169,12 +168,12 @@ async def test_100_continue(server):
 
 @pytest.mark.usefixtures("async_environment")
 async def test_context_managed_transport():
-    class Transport(httpcore.AsyncHTTPTransport):
+    class Transport(httpx.AsyncBaseTransport):
         def __init__(self):
             self.events = []
 
         async def aclose(self):
-            # The base implementation of httpcore.AsyncHTTPTransport just
+            # The base implementation of httpx.AsyncBaseTransport just
             # calls into `.aclose`, so simple transport cases can just override
             # this method for any cleanup, where more complex cases
             # might want to additionally override `__aenter__`/`__aexit__`.
@@ -201,13 +200,13 @@ async def test_context_managed_transport():
 
 @pytest.mark.usefixtures("async_environment")
 async def test_context_managed_transport_and_mount():
-    class Transport(httpcore.AsyncHTTPTransport):
+    class Transport(httpx.AsyncBaseTransport):
         def __init__(self, name: str):
             self.name: str = name
             self.events: typing.List[str] = []
 
         async def aclose(self):
-            # The base implementation of httpcore.AsyncHTTPTransport just
+            # The base implementation of httpx.AsyncBaseTransport just
             # calls into `.aclose`, so simple transport cases can just override
             # this method for any cleanup, where more complex cases
             # might want to additionally override `__aenter__`/`__aexit__`.
@@ -303,25 +302,6 @@ async def test_mounted_transport():
         assert response.json() == {"app": "mounted"}
 
 
-@pytest.mark.usefixtures("async_environment")
-async def test_response_aclose_map_exceptions():
-    class BrokenStream:
-        async def __aiter__(self):
-            # so we're an AsyncIterator
-            pass  # pragma: nocover
-
-        async def aclose(self):
-            raise httpcore.CloseError(OSError(104, "Connection reset by peer"))
-
-    def handle(request: httpx.Request) -> httpx.Response:
-        return httpx.Response(200, stream=BrokenStream())
-
-    async with httpx.AsyncClient(transport=httpx.MockTransport(handle)) as client:
-        async with client.stream("GET", "http://example.com") as response:
-            with pytest.raises(httpx.CloseError):
-                await response.aclose()
-
-
 @pytest.mark.usefixtures("async_environment")
 async def test_async_mock_transport():
     async def hello_world(request):
index 7e32bcf6f3a53e42960e5f9ec921323ceee3d59f..386cd7480c6261e48a598b5d5a4384ed4365faf5 100644 (file)
@@ -1,7 +1,6 @@
 import typing
 from datetime import timedelta
 
-import httpcore
 import pytest
 
 import httpx
@@ -224,12 +223,12 @@ def test_pool_limits_deprecated():
 
 
 def test_context_managed_transport():
-    class Transport(httpcore.SyncHTTPTransport):
+    class Transport(httpx.BaseTransport):
         def __init__(self):
             self.events = []
 
         def close(self):
-            # The base implementation of httpcore.SyncHTTPTransport just
+            # The base implementation of httpx.BaseTransport just
             # calls into `.close`, so simple transport cases can just override
             # this method for any cleanup, where more complex cases
             # might want to additionally override `__enter__`/`__exit__`.
@@ -255,13 +254,13 @@ def test_context_managed_transport():
 
 
 def test_context_managed_transport_and_mount():
-    class Transport(httpcore.SyncHTTPTransport):
+    class Transport(httpx.BaseTransport):
         def __init__(self, name: str):
             self.name: str = name
             self.events: typing.List[str] = []
 
         def close(self):
-            # The base implementation of httpcore.SyncHTTPTransport just
+            # The base implementation of httpx.BaseTransport just
             # calls into `.close`, so simple transport cases can just override
             # this method for any cleanup, where more complex cases
             # might want to additionally override `__enter__`/`__exit__`.
index 84d371e9fa55c643dad9abd84daef0ec5c09c56f..22c5aa0f1ad2b84b7d7ddb656ccf7475a8e46db9 100644 (file)
@@ -1,4 +1,3 @@
-import httpcore
 import pytest
 
 import httpx
@@ -6,9 +5,7 @@ import httpx
 
 def redirects(request: httpx.Request) -> httpx.Response:
     if request.url.scheme not in ("http", "https"):
-        raise httpcore.UnsupportedProtocol(
-            f"Scheme {request.url.scheme!r} not supported."
-        )
+        raise httpx.UnsupportedProtocol(f"Scheme {request.url.scheme!r} not supported.")
 
     if request.url.path == "/redirect_301":
         status_code = httpx.codes.MOVED_PERMANENTLY
@@ -396,3 +393,10 @@ def test_redirect_custom_scheme():
     with pytest.raises(httpx.UnsupportedProtocol) as e:
         client.post("https://example.org/redirect_custom_scheme")
     assert str(e.value) == "Scheme 'market' not supported."
+
+
+@pytest.mark.usefixtures("async_environment")
+async def test_async_invalid_redirect():
+    async with httpx.AsyncClient(transport=httpx.MockTransport(redirects)) as client:
+        with pytest.raises(httpx.RemoteProtocolError):
+            await client.get("http://example.org/invalid_redirect")
index 12db1b0bb2ad8523bc2420db39ee9502927ad778..62c10c9fb4cf9acb51ca204e609721807af916be 100644 (file)
@@ -76,8 +76,6 @@ async def app(scope, receive, send):
     assert scope["type"] == "http"
     if scope["path"].startswith("/slow_response"):
         await slow_response(scope, receive, send)
-    elif scope["path"].startswith("/slow_stream_response"):
-        await slow_stream_response(scope, receive, send)
     elif scope["path"].startswith("/status"):
         await status_code(scope, receive, send)
     elif scope["path"].startswith("/echo_body"):
@@ -113,19 +111,6 @@ async def slow_response(scope, receive, send):
     await send({"type": "http.response.body", "body": b"Hello, world!"})
 
 
-async def slow_stream_response(scope, receive, send):
-    await send(
-        {
-            "type": "http.response.start",
-            "status": 200,
-            "headers": [[b"content-type", b"text/plain"]],
-        }
-    )
-
-    await sleep(1)
-    await send({"type": "http.response.body", "body": b"", "more_body": False})
-
-
 async def status_code(scope, receive, send):
     status_code = int(scope["path"].replace("/status/", ""))
     await send(
index cb46719c17d43b2cb7b5c30826de723ad32872f8..793fad3b76e7c2a1790144bd421054e1408fdea4 100644 (file)
@@ -733,7 +733,7 @@ def test_json_without_specified_encoding_value_error():
     # force incorrect guess from `guess_json_utf` to trigger error
     with mock.patch("httpx._models.guess_json_utf", return_value="utf-32"):
         response = httpx.Response(200, content=content, headers=headers)
-        with pytest.raises(ValueError):
+        with pytest.raises(json.decoder.JSONDecodeError):
             response.json()
 
 
@@ -767,7 +767,7 @@ def test_decode_error_with_request(header_value):
     headers = [(b"Content-Encoding", header_value)]
     body = b"test 123"
     compressed_body = brotli.compress(body)[3:]
-    with pytest.raises(ValueError):
+    with pytest.raises(httpx.DecodingError):
         httpx.Response(
             200,
             headers=headers,
@@ -788,7 +788,7 @@ def test_value_error_without_request(header_value):
     headers = [(b"Content-Encoding", header_value)]
     body = b"test 123"
     compressed_body = brotli.compress(body)[3:]
-    with pytest.raises(ValueError):
+    with pytest.raises(httpx.DecodingError):
         httpx.Response(200, headers=headers, content=compressed_body)
 
 
index b16f68246cd5070db606a9654c7bdef689445017..d7cf9412af2ba6c30c6965f0b9ba888a8c8635a8 100644 (file)
@@ -70,6 +70,42 @@ async def raise_exc_after_response(scope, receive, send):
     raise RuntimeError()
 
 
+async def empty_stream():
+    yield b""
+
+
+@pytest.mark.usefixtures("async_environment")
+async def test_asgi_transport():
+    async with httpx.ASGITransport(app=hello_world) as transport:
+        status_code, headers, stream, ext = await transport.handle_async_request(
+            method=b"GET",
+            url=(b"http", b"www.example.org", 80, b"/"),
+            headers=[(b"Host", b"www.example.org")],
+            stream=empty_stream(),
+            extensions={},
+        )
+        body = b"".join([part async for part in stream])
+
+        assert status_code == 200
+        assert body == b"Hello, World!"
+
+
+@pytest.mark.usefixtures("async_environment")
+async def test_asgi_transport_no_body():
+    async with httpx.ASGITransport(app=echo_body) as transport:
+        status_code, headers, stream, ext = await transport.handle_async_request(
+            method=b"GET",
+            url=(b"http", b"www.example.org", 80, b"/"),
+            headers=[(b"Host", b"www.example.org")],
+            stream=empty_stream(),
+            extensions={},
+        )
+        body = b"".join([part async for part in stream])
+
+        assert status_code == 200
+        assert body == b""
+
+
 @pytest.mark.usefixtures("async_environment")
 async def test_asgi():
     async with httpx.AsyncClient(app=hello_world) as client:
index f8c432cc8981dc5df79b398afe3a465255488a67..faaf71d2fb9461b08e2e5cb50767c284a4d7f7e8 100644 (file)
@@ -170,7 +170,7 @@ def test_decoding_errors(header_value):
         request = httpx.Request("GET", "https://example.org")
         httpx.Response(200, headers=headers, content=compressed_body, request=request)
 
-    with pytest.raises(ValueError):
+    with pytest.raises(httpx.DecodingError):
         httpx.Response(200, headers=headers, content=compressed_body)
 
 
index f1c7005bbae84494eb2ec60d609318f734aeea5f..1bc6723a879c44da88756cc028f5a45cff988fbc 100644 (file)
@@ -1,10 +1,10 @@
-from typing import Any
+from unittest import mock
 
 import httpcore
 import pytest
 
 import httpx
-from httpx._exceptions import HTTPCORE_EXC_MAP
+from httpx._transports.default import HTTPCORE_EXC_MAP
 
 
 def test_httpcore_all_exceptions_mapped() -> None:
@@ -29,25 +29,40 @@ def test_httpcore_exception_mapping(server) -> None:
     HTTPCore exception mapping works as expected.
     """
 
-    # Make sure we don't just map to `NetworkError`.
-    with pytest.raises(httpx.ConnectError):
-        httpx.get("http://doesnotexist")
+    def connect_failed(*args, **kwargs):
+        raise httpcore.ConnectError()
 
-    # Make sure streaming methods also map exceptions.
-    url = server.url.copy_with(path="/slow_stream_response")
-    timeout = httpx.Timeout(None, read=0.1)
-    with httpx.stream("GET", url, timeout=timeout) as stream:
-        with pytest.raises(httpx.ReadTimeout):
-            stream.read()
+    class TimeoutStream:
+        def __iter__(self):
+            raise httpcore.ReadTimeout()
+
+        def close(self):
+            pass
+
+    class CloseFailedStream:
+        def __iter__(self):
+            yield b""
 
-    # Make sure it also works with custom transports.
-    class MockTransport(httpcore.SyncHTTPTransport):
-        def request(self, *args: Any, **kwargs: Any) -> Any:
-            raise httpcore.ProtocolError()
+        def close(self):
+            raise httpcore.CloseError()
 
-    client = httpx.Client(transport=MockTransport())
-    with pytest.raises(httpx.ProtocolError):
-        client.get("http://testserver")
+    with mock.patch("httpcore.SyncConnectionPool.request", side_effect=connect_failed):
+        with pytest.raises(httpx.ConnectError):
+            httpx.get(server.url)
+
+    with mock.patch(
+        "httpcore.SyncConnectionPool.request",
+        return_value=(200, [], TimeoutStream(), {}),
+    ):
+        with pytest.raises(httpx.ReadTimeout):
+            httpx.get(server.url)
+
+    with mock.patch(
+        "httpcore.SyncConnectionPool.request",
+        return_value=(200, [], CloseFailedStream(), {}),
+    ):
+        with pytest.raises(httpx.CloseError):
+            httpx.get(server.url)
 
 
 def test_httpx_exceptions_exposed() -> None:
@@ -66,3 +81,15 @@ def test_httpx_exceptions_exposed() -> None:
 
     if not_exposed:  # pragma: nocover
         pytest.fail(f"Unexposed HTTPX exceptions: {not_exposed}")
+
+
+def test_request_attribute() -> None:
+    # Exception without request attribute
+    exc = httpx.ReadTimeout("Read operation timed out")
+    with pytest.raises(RuntimeError):
+        exc.request
+
+    # Exception with request attribute
+    request = httpx.Request("GET", "https://www.example.com")
+    exc = httpx.ReadTimeout("Read operation timed out", request=request)
+    assert exc.request == request