Transport API as plain `request -> response` method. (#1840)
author    Tom Christie <tom@tomchristie.com>
Mon, 13 Sep 2021 12:34:46 +0000 (13:34 +0100)
committer GitHub <noreply@github.com>
Mon, 13 Sep 2021 12:34:46 +0000 (13:34 +0100)
* Responses as context managers

* timeout -> request.extensions

* Transport API -> request/response signature

* Fix top-level httpx.stream()

* Drop response context manager methods

* Simplify ASGI tests

* Black formatting
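
These changes reduce the transport interface to a single plain `handle_request(request) -> Response` method (`handle_async_request` for async transports), with per-request settings such as the timeout carried on `request.extensions` instead of being passed as separate arguments. A minimal sketch of a custom sync transport against the new signature (the `HelloWorldTransport` name and canned response are illustrative only, not part of this commit):

```python
import httpx


class HelloWorldTransport(httpx.BaseTransport):
    """Example-only transport: answers every request with a canned response."""

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        # Per-request configuration now travels on the request itself; the
        # timeout dict is filled in by Client.build_request(), e.g.
        # {"connect": 5.0, "read": 5.0, "write": 5.0, "pool": 5.0}.
        timeout = request.extensions.get("timeout", {})
        body = f"Hello from {request.url.host}, timeout={timeout}".encode("utf-8")
        return httpx.Response(200, headers={"Content-Type": "text/plain"}, content=body)


# Usage sketch:
# client = httpx.Client(transport=HelloWorldTransport(), timeout=5.0)
# print(client.get("https://example.org/").text)
```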

14 files changed:
httpx/__init__.py
httpx/_client.py
httpx/_content.py
httpx/_models.py
httpx/_multipart.py
httpx/_transports/asgi.py
httpx/_transports/base.py
httpx/_transports/default.py
httpx/_transports/mock.py
httpx/_transports/wsgi.py
httpx/_types.py
tests/client/test_auth.py
tests/client/test_redirects.py
tests/test_asgi.py

diff --git a/httpx/__init__.py b/httpx/__init__.py
index 4af3904fd36a038400a77f18b13f9d1e17e4b255..bfce57639fa03442e0cd38db2e8053c4635033b7 100644
@@ -37,15 +37,11 @@ from ._exceptions import (
 from ._models import URL, Cookies, Headers, QueryParams, Request, Response
 from ._status_codes import codes
 from ._transports.asgi import ASGITransport
-from ._transports.base import (
-    AsyncBaseTransport,
-    AsyncByteStream,
-    BaseTransport,
-    SyncByteStream,
-)
+from ._transports.base import AsyncBaseTransport, BaseTransport
 from ._transports.default import AsyncHTTPTransport, HTTPTransport
 from ._transports.mock import MockTransport
 from ._transports.wsgi import WSGITransport
+from ._types import AsyncByteStream, SyncByteStream
 
 __all__ = [
     "__description__",
diff --git a/httpx/_client.py b/httpx/_client.py
index 6e8bb2f31dbd30f826ef3a88b8ec5586402f2b5e..7492cb45f238794f9e80e1a1e5995758a8af8dc2 100644
@@ -26,15 +26,11 @@ from ._exceptions import (
 from ._models import URL, Cookies, Headers, QueryParams, Request, Response
 from ._status_codes import codes
 from ._transports.asgi import ASGITransport
-from ._transports.base import (
-    AsyncBaseTransport,
-    AsyncByteStream,
-    BaseTransport,
-    SyncByteStream,
-)
+from ._transports.base import AsyncBaseTransport, BaseTransport
 from ._transports.default import AsyncHTTPTransport, HTTPTransport
 from ._transports.wsgi import WSGITransport
 from ._types import (
+    AsyncByteStream,
     AuthTypes,
     CertTypes,
     CookieTypes,
@@ -44,6 +40,7 @@ from ._types import (
     RequestContent,
     RequestData,
     RequestFiles,
+    SyncByteStream,
     TimeoutTypes,
     URLTypes,
     VerifyTypes,
@@ -327,6 +324,7 @@ class BaseClient:
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         cookies: CookieTypes = None,
+        timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT,
     ) -> Request:
         """
         Build and return a request instance.
@@ -343,6 +341,9 @@ class BaseClient:
         headers = self._merge_headers(headers)
         cookies = self._merge_cookies(cookies)
         params = self._merge_queryparams(params)
+        timeout = (
+            self.timeout if isinstance(timeout, UseClientDefault) else Timeout(timeout)
+        )
         return Request(
             method,
             url,
@@ -353,6 +354,7 @@ class BaseClient:
             params=params,
             headers=headers,
             cookies=cookies,
+            extensions={"timeout": timeout.as_dict()},
         )
 
     def _merge_url(self, url: URLTypes) -> URL:
@@ -785,10 +787,9 @@ class Client(BaseClient):
             params=params,
             headers=headers,
             cookies=cookies,
+            timeout=timeout,
         )
-        return self.send(
-            request, auth=auth, follow_redirects=follow_redirects, timeout=timeout
-        )
+        return self.send(request, auth=auth, follow_redirects=follow_redirects)
 
     @contextmanager
     def stream(
@@ -827,12 +828,12 @@ class Client(BaseClient):
             params=params,
             headers=headers,
             cookies=cookies,
+            timeout=timeout,
         )
         response = self.send(
             request=request,
             auth=auth,
             follow_redirects=follow_redirects,
-            timeout=timeout,
             stream=True,
         )
         try:
@@ -847,7 +848,6 @@ class Client(BaseClient):
         stream: bool = False,
         auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT,
         follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT,
-        timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT,
     ) -> Response:
         """
         Send a request.
@@ -866,9 +866,6 @@ class Client(BaseClient):
             raise RuntimeError("Cannot send a request, as the client has been closed.")
 
         self._state = ClientState.OPENED
-        timeout = (
-            self.timeout if isinstance(timeout, UseClientDefault) else Timeout(timeout)
-        )
         follow_redirects = (
             self.follow_redirects
             if isinstance(follow_redirects, UseClientDefault)
@@ -880,7 +877,6 @@ class Client(BaseClient):
         response = self._send_handling_auth(
             request,
             auth=auth,
-            timeout=timeout,
             follow_redirects=follow_redirects,
             history=[],
         )
@@ -898,7 +894,6 @@ class Client(BaseClient):
         self,
         request: Request,
         auth: Auth,
-        timeout: Timeout,
         follow_redirects: bool,
         history: typing.List[Response],
     ) -> Response:
@@ -909,7 +904,6 @@ class Client(BaseClient):
             while True:
                 response = self._send_handling_redirects(
                     request,
-                    timeout=timeout,
                     follow_redirects=follow_redirects,
                     history=history,
                 )
@@ -933,7 +927,6 @@ class Client(BaseClient):
     def _send_handling_redirects(
         self,
         request: Request,
-        timeout: Timeout,
         follow_redirects: bool,
         history: typing.List[Response],
     ) -> Response:
@@ -946,7 +939,7 @@ class Client(BaseClient):
             for hook in self._event_hooks["request"]:
                 hook(request)
 
-            response = self._send_single_request(request, timeout)
+            response = self._send_single_request(request)
             try:
                 for hook in self._event_hooks["response"]:
                     hook(response)
@@ -968,7 +961,7 @@ class Client(BaseClient):
                 response.close()
                 raise exc
 
-    def _send_single_request(self, request: Request, timeout: Timeout) -> Response:
+    def _send_single_request(self, request: Request) -> Response:
         """
         Sends a single request, without handling any redirections.
         """
@@ -982,23 +975,14 @@ class Client(BaseClient):
             )
 
         with request_context(request=request):
-            (status_code, headers, stream, extensions) = transport.handle_request(
-                request.method.encode(),
-                request.url.raw,
-                headers=request.headers.raw,
-                stream=request.stream,
-                extensions={"timeout": timeout.as_dict()},
-            )
+            response = transport.handle_request(request)
 
-        response = Response(
-            status_code,
-            headers=headers,
-            stream=stream,
-            extensions=extensions,
-            request=request,
-        )
+        assert isinstance(response.stream, SyncByteStream)
 
-        response.stream = BoundSyncStream(stream, response=response, timer=timer)
+        response.request = request
+        response.stream = BoundSyncStream(
+            response.stream, response=response, timer=timer
+        )
         self.cookies.extract_cookies(response)
 
         status = f"{response.status_code} {response.reason_phrase}"
@@ -1494,9 +1478,10 @@ class AsyncClient(BaseClient):
             params=params,
             headers=headers,
             cookies=cookies,
+            timeout=timeout,
         )
         response = await self.send(
-            request, auth=auth, follow_redirects=follow_redirects, timeout=timeout
+            request, auth=auth, follow_redirects=follow_redirects
         )
         return response
 
@@ -1537,12 +1522,12 @@ class AsyncClient(BaseClient):
             params=params,
             headers=headers,
             cookies=cookies,
+            timeout=timeout,
         )
         response = await self.send(
             request=request,
             auth=auth,
             follow_redirects=follow_redirects,
-            timeout=timeout,
             stream=True,
         )
         try:
@@ -1557,7 +1542,6 @@ class AsyncClient(BaseClient):
         stream: bool = False,
         auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT,
         follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT,
-        timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT,
     ) -> Response:
         """
         Send a request.
@@ -1576,9 +1560,6 @@ class AsyncClient(BaseClient):
             raise RuntimeError("Cannot send a request, as the client has been closed.")
 
         self._state = ClientState.OPENED
-        timeout = (
-            self.timeout if isinstance(timeout, UseClientDefault) else Timeout(timeout)
-        )
         follow_redirects = (
             self.follow_redirects
             if isinstance(follow_redirects, UseClientDefault)
@@ -1590,7 +1571,6 @@ class AsyncClient(BaseClient):
         response = await self._send_handling_auth(
             request,
             auth=auth,
-            timeout=timeout,
             follow_redirects=follow_redirects,
             history=[],
         )
@@ -1608,7 +1588,6 @@ class AsyncClient(BaseClient):
         self,
         request: Request,
         auth: Auth,
-        timeout: Timeout,
         follow_redirects: bool,
         history: typing.List[Response],
     ) -> Response:
@@ -1619,7 +1598,6 @@ class AsyncClient(BaseClient):
             while True:
                 response = await self._send_handling_redirects(
                     request,
-                    timeout=timeout,
                     follow_redirects=follow_redirects,
                     history=history,
                 )
@@ -1643,7 +1621,6 @@ class AsyncClient(BaseClient):
     async def _send_handling_redirects(
         self,
         request: Request,
-        timeout: Timeout,
         follow_redirects: bool,
         history: typing.List[Response],
     ) -> Response:
@@ -1656,7 +1633,7 @@ class AsyncClient(BaseClient):
             for hook in self._event_hooks["request"]:
                 await hook(request)
 
-            response = await self._send_single_request(request, timeout)
+            response = await self._send_single_request(request)
             try:
                 for hook in self._event_hooks["response"]:
                     await hook(response)
@@ -1679,9 +1656,7 @@ class AsyncClient(BaseClient):
                 await response.aclose()
                 raise exc
 
-    async def _send_single_request(
-        self, request: Request, timeout: Timeout
-    ) -> Response:
+    async def _send_single_request(self, request: Request) -> Response:
         """
         Sends a single request, without handling any redirections.
         """
@@ -1695,28 +1670,13 @@ class AsyncClient(BaseClient):
             )
 
         with request_context(request=request):
-            (
-                status_code,
-                headers,
-                stream,
-                extensions,
-            ) = await transport.handle_async_request(
-                request.method.encode(),
-                request.url.raw,
-                headers=request.headers.raw,
-                stream=request.stream,
-                extensions={"timeout": timeout.as_dict()},
-            )
+            response = await transport.handle_async_request(request)
 
-        response = Response(
-            status_code,
-            headers=headers,
-            stream=stream,
-            extensions=extensions,
-            request=request,
+        assert isinstance(response.stream, AsyncByteStream)
+        response.request = request
+        response.stream = BoundAsyncStream(
+            response.stream, response=response, timer=timer
         )
-
-        response.stream = BoundAsyncStream(stream, response=response, timer=timer)
         self.cookies.extract_cookies(response)
 
         status = f"{response.status_code} {response.reason_phrase}"
diff --git a/httpx/_content.py b/httpx/_content.py
index 86f3c7c254109afc5e3967ca1c92ed0bc4d885b9..d7e8aa097493af731b5bdea56e2bcfdedc0ad388 100644
@@ -15,8 +15,14 @@ from urllib.parse import urlencode
 
 from ._exceptions import StreamClosed, StreamConsumed
 from ._multipart import MultipartStream
-from ._transports.base import AsyncByteStream, SyncByteStream
-from ._types import RequestContent, RequestData, RequestFiles, ResponseContent
+from ._types import (
+    AsyncByteStream,
+    RequestContent,
+    RequestData,
+    RequestFiles,
+    ResponseContent,
+    SyncByteStream,
+)
 from ._utils import peek_filelike_length, primitive_value_to_str
 
 
diff --git a/httpx/_models.py b/httpx/_models.py
index 0a54a6fa6825d6342258a435ec1c7e961d2f4d3a..7c6460e73fda35560df88061a31334bf5697bd1e 100644
@@ -35,8 +35,8 @@ from ._exceptions import (
     request_context,
 )
 from ._status_codes import codes
-from ._transports.base import AsyncByteStream, SyncByteStream
 from ._types import (
+    AsyncByteStream,
     CookieTypes,
     HeaderTypes,
     PrimitiveData,
@@ -46,6 +46,7 @@ from ._types import (
     RequestData,
     RequestFiles,
     ResponseContent,
+    SyncByteStream,
     URLTypes,
 )
 from ._utils import (
@@ -1081,15 +1082,19 @@ class Request:
         files: RequestFiles = None,
         json: typing.Any = None,
         stream: typing.Union[SyncByteStream, AsyncByteStream] = None,
+        extensions: dict = None,
     ):
-        if isinstance(method, bytes):
-            self.method = method.decode("ascii").upper()
-        else:
-            self.method = method.upper()
+        self.method = (
+            method.decode("ascii").upper()
+            if isinstance(method, bytes)
+            else method.upper()
+        )
         self.url = URL(url)
         if params is not None:
             self.url = self.url.copy_merge_params(params=params)
         self.headers = Headers(headers)
+        self.extensions = {} if extensions is None else extensions
+
         if cookies:
             Cookies(cookies).set_cookie_header(self)
 
diff --git a/httpx/_multipart.py b/httpx/_multipart.py
index 683e6f1311a881006f014123a1223c42fb8c2c48..4dfb838a68109ab296f32d85ce812b6273675f68 100644
@@ -4,8 +4,13 @@ import os
 import typing
 from pathlib import Path
 
-from ._transports.base import AsyncByteStream, SyncByteStream
-from ._types import FileContent, FileTypes, RequestFiles
+from ._types import (
+    AsyncByteStream,
+    FileContent,
+    FileTypes,
+    RequestFiles,
+    SyncByteStream,
+)
 from ._utils import (
     format_form_param,
     guess_content_type,
diff --git a/httpx/_transports/asgi.py b/httpx/_transports/asgi.py
index 24c5452dc921ea3ea5a37619c6e591c24942872e..4e3616587dfcf2afd0c6e217c4e56a020f661682 100644
@@ -1,9 +1,10 @@
 import typing
-from urllib.parse import unquote
 
 import sniffio
 
-from .base import AsyncBaseTransport, AsyncByteStream
+from .._models import Request, Response
+from .._types import AsyncByteStream
+from .base import AsyncBaseTransport
 
 if typing.TYPE_CHECKING:  # pragma: no cover
     import asyncio
@@ -79,34 +80,28 @@ class ASGITransport(AsyncBaseTransport):
 
     async def handle_async_request(
         self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: AsyncByteStream,
-        extensions: dict,
-    ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], AsyncByteStream, dict
-    ]:
+        request: Request,
+    ) -> Response:
+        assert isinstance(request.stream, AsyncByteStream)
+
         # ASGI scope.
-        scheme, host, port, full_path = url
-        path, _, query = full_path.partition(b"?")
         scope = {
             "type": "http",
             "asgi": {"version": "3.0"},
             "http_version": "1.1",
-            "method": method.decode(),
-            "headers": [(k.lower(), v) for (k, v) in headers],
-            "scheme": scheme.decode("ascii"),
-            "path": unquote(path.decode("ascii")),
-            "raw_path": path,
-            "query_string": query,
-            "server": (host.decode("ascii"), port),
+            "method": request.method,
+            "headers": [(k.lower(), v) for (k, v) in request.headers.raw],
+            "scheme": request.url.scheme,
+            "path": request.url.path,
+            "raw_path": request.url.raw_path,
+            "query_string": request.url.query,
+            "server": (request.url.host, request.url.port),
             "client": self.client,
             "root_path": self.root_path,
         }
 
         # Request.
-        request_body_chunks = stream.__aiter__()
+        request_body_chunks = request.stream.__aiter__()
         request_complete = False
 
         # Response.
@@ -147,7 +142,7 @@ class ASGITransport(AsyncBaseTransport):
                 body = message.get("body", b"")
                 more_body = message.get("more_body", False)
 
-                if body and method != b"HEAD":
+                if body and request.method != "HEAD":
                     body_parts.append(body)
 
                 if not more_body:
@@ -164,6 +159,5 @@ class ASGITransport(AsyncBaseTransport):
         assert response_headers is not None
 
         stream = ASGIResponseStream(body_parts)
-        extensions = {}
 
-        return (status_code, response_headers, stream, extensions)
+        return Response(status_code, headers=response_headers, stream=stream)
diff --git a/httpx/_transports/base.py b/httpx/_transports/base.py
index eb519269704882d424f6e7c4c2d64bef5d6b668b..8c324ab4cd3a7c0732a8043a92690c03d6b4663c 100644
@@ -1,67 +1,12 @@
 import typing
 from types import TracebackType
 
+from .._models import Request, Response
+
 T = typing.TypeVar("T", bound="BaseTransport")
 A = typing.TypeVar("A", bound="AsyncBaseTransport")
 
 
-class SyncByteStream:
-    def __iter__(self) -> typing.Iterator[bytes]:
-        raise NotImplementedError(
-            "The '__iter__' method must be implemented."
-        )  # pragma: nocover
-        yield b""  # pragma: nocover
-
-    def close(self) -> None:
-        """
-        Subclasses can override this method to release any network resources
-        after a request/response cycle is complete.
-
-        Streaming cases should use a `try...finally` block to ensure that
-        the stream `close()` method is always called.
-
-        Example:
-
-            status_code, headers, stream, extensions = transport.handle_request(...)
-            try:
-                ...
-            finally:
-                stream.close()
-        """
-
-    def read(self) -> bytes:
-        """
-        Simple cases can use `.read()` as a convience method for consuming
-        the entire stream and then closing it.
-
-        Example:
-
-            status_code, headers, stream, extensions = transport.handle_request(...)
-            body = stream.read()
-        """
-        try:
-            return b"".join([part for part in self])
-        finally:
-            self.close()
-
-
-class AsyncByteStream:
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        raise NotImplementedError(
-            "The '__aiter__' method must be implemented."
-        )  # pragma: nocover
-        yield b""  # pragma: nocover
-
-    async def aclose(self) -> None:
-        pass
-
-    async def aread(self) -> bytes:
-        try:
-            return b"".join([part async for part in self])
-        finally:
-            await self.aclose()
-
-
 class BaseTransport:
     def __enter__(self: T) -> T:
         return self
@@ -74,16 +19,7 @@ class BaseTransport:
     ) -> None:
         self.close()
 
-    def handle_request(
-        self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: SyncByteStream,
-        extensions: dict,
-    ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], SyncByteStream, dict
-    ]:
+    def handle_request(self, request: Request) -> Response:
         """
         Send a single HTTP request and return a response.
 
@@ -167,14 +103,8 @@ class AsyncBaseTransport:
 
     async def handle_async_request(
         self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: AsyncByteStream,
-        extensions: dict,
-    ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], AsyncByteStream, dict
-    ]:
+        request: Request,
+    ) -> Response:
         raise NotImplementedError(
             "The 'handle_async_request' method must be implemented."
         )  # pragma: nocover
diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py
index 73401fce66c8b793acc383307f7279fef1a29c7c..2566a3f217a27603c8f493c6bdca34d97a03582e 100644
@@ -48,8 +48,9 @@ from .._exceptions import (
     WriteError,
     WriteTimeout,
 )
-from .._types import CertTypes, VerifyTypes
-from .base import AsyncBaseTransport, AsyncByteStream, BaseTransport, SyncByteStream
+from .._models import Request, Response
+from .._types import AsyncByteStream, CertTypes, SyncByteStream, VerifyTypes
+from .base import AsyncBaseTransport, BaseTransport
 
 T = typing.TypeVar("T", bound="HTTPTransport")
 A = typing.TypeVar("A", bound="AsyncHTTPTransport")
@@ -168,26 +169,24 @@ class HTTPTransport(BaseTransport):
 
     def handle_request(
         self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: SyncByteStream,
-        extensions: dict,
-    ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], SyncByteStream, dict
-    ]:
+        request: Request,
+    ) -> Response:
+        assert isinstance(request.stream, SyncByteStream)
+
         with map_httpcore_exceptions():
             status_code, headers, byte_stream, extensions = self._pool.handle_request(
-                method=method,
-                url=url,
-                headers=headers,
-                stream=httpcore.IteratorByteStream(iter(stream)),
-                extensions=extensions,
+                method=request.method.encode("ascii"),
+                url=request.url.raw,
+                headers=request.headers.raw,
+                stream=httpcore.IteratorByteStream(iter(request.stream)),
+                extensions=request.extensions,
             )
 
         stream = ResponseStream(byte_stream)
 
-        return status_code, headers, stream, extensions
+        return Response(
+            status_code, headers=headers, stream=stream, extensions=extensions
+        )
 
     def close(self) -> None:
         self._pool.close()
@@ -264,14 +263,10 @@ class AsyncHTTPTransport(AsyncBaseTransport):
 
     async def handle_async_request(
         self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: AsyncByteStream,
-        extensions: dict,
-    ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], AsyncByteStream, dict
-    ]:
+        request: Request,
+    ) -> Response:
+        assert isinstance(request.stream, AsyncByteStream)
+
         with map_httpcore_exceptions():
             (
                 status_code,
@@ -279,16 +274,18 @@ class AsyncHTTPTransport(AsyncBaseTransport):
                 byte_stream,
                 extensions,
             ) = await self._pool.handle_async_request(
-                method=method,
-                url=url,
-                headers=headers,
-                stream=httpcore.AsyncIteratorByteStream(stream.__aiter__()),
-                extensions=extensions,
+                method=request.method.encode("ascii"),
+                url=request.url.raw,
+                headers=request.headers.raw,
+                stream=httpcore.AsyncIteratorByteStream(request.stream.__aiter__()),
+                extensions=request.extensions,
             )
 
         stream = AsyncResponseStream(byte_stream)
 
-        return status_code, headers, stream, extensions
+        return Response(
+            status_code, headers=headers, stream=stream, extensions=extensions
+        )
 
     async def aclose(self) -> None:
         await self._pool.aclose()
diff --git a/httpx/_transports/mock.py b/httpx/_transports/mock.py
index 8d59b73820d87a85ee7ebfef2097e1e1e71a7cdc..f61aee710114cb0b3740abd7c55eba2dfc5d2a92 100644
@@ -1,8 +1,8 @@
 import asyncio
 import typing
 
-from .._models import Request
-from .base import AsyncBaseTransport, AsyncByteStream, BaseTransport, SyncByteStream
+from .._models import Request, Response
+from .base import AsyncBaseTransport, BaseTransport
 
 
 class MockTransport(AsyncBaseTransport, BaseTransport):
@@ -11,47 +11,16 @@ class MockTransport(AsyncBaseTransport, BaseTransport):
 
     def handle_request(
         self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: SyncByteStream,
-        extensions: dict,
-    ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], SyncByteStream, dict
-    ]:
-        request = Request(
-            method=method,
-            url=url,
-            headers=headers,
-            stream=stream,
-        )
+        request: Request,
+    ) -> Response:
         request.read()
-        response = self.handler(request)
-        return (
-            response.status_code,
-            response.headers.raw,
-            response.stream,
-            response.extensions,
-        )
+        return self.handler(request)
 
     async def handle_async_request(
         self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: AsyncByteStream,
-        extensions: dict,
-    ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], AsyncByteStream, dict
-    ]:
-        request = Request(
-            method=method,
-            url=url,
-            headers=headers,
-            stream=stream,
-        )
+        request: Request,
+    ) -> Response:
         await request.aread()
-
         response = self.handler(request)
 
         # Allow handler to *optionally* be an `async` function.
@@ -62,9 +31,4 @@ class MockTransport(AsyncBaseTransport, BaseTransport):
         if asyncio.iscoroutine(response):
             response = await response
 
-        return (
-            response.status_code,
-            response.headers.raw,
-            response.stream,
-            response.extensions,
-        )
+        return response
diff --git a/httpx/_transports/wsgi.py b/httpx/_transports/wsgi.py
index e8bdfd3f687c1e3bbfcb617191d558bbd1bc2b24..3dedf49f96af8f4176f224a9c7322a0ec2008dae 100644
@@ -2,9 +2,10 @@ import io
 import itertools
 import sys
 import typing
-from urllib.parse import unquote
 
-from .base import BaseTransport, SyncByteStream
+from .._models import Request, Response
+from .._types import SyncByteStream
+from .base import BaseTransport
 
 
 def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable:
@@ -76,40 +77,28 @@ class WSGITransport(BaseTransport):
         self.remote_addr = remote_addr
         self.wsgi_errors = wsgi_errors
 
-    def handle_request(
-        self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: SyncByteStream,
-        extensions: dict,
-    ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], SyncByteStream, dict
-    ]:
-        wsgi_input = io.BytesIO(b"".join(stream))
-
-        scheme, host, port, full_path = url
-        path, _, query = full_path.partition(b"?")
-        if port is None:
-            port = {b"http": 80, b"https": 443}[scheme]
+    def handle_request(self, request: Request) -> Response:
+        request.read()
+        wsgi_input = io.BytesIO(request.content)
 
+        port = request.url.port or {"http": 80, "https": 443}[request.url.scheme]
         environ = {
             "wsgi.version": (1, 0),
-            "wsgi.url_scheme": scheme.decode("ascii"),
+            "wsgi.url_scheme": request.url.scheme,
             "wsgi.input": wsgi_input,
             "wsgi.errors": self.wsgi_errors or sys.stderr,
             "wsgi.multithread": True,
             "wsgi.multiprocess": False,
             "wsgi.run_once": False,
-            "REQUEST_METHOD": method.decode(),
+            "REQUEST_METHOD": request.method,
             "SCRIPT_NAME": self.script_name,
-            "PATH_INFO": unquote(path.decode("ascii")),
-            "QUERY_STRING": query.decode("ascii"),
-            "SERVER_NAME": host.decode("ascii"),
+            "PATH_INFO": request.url.path,
+            "QUERY_STRING": request.url.query.decode("ascii"),
+            "SERVER_NAME": request.url.host,
             "SERVER_PORT": str(port),
             "REMOTE_ADDR": self.remote_addr,
         }
-        for header_key, header_value in headers:
+        for header_key, header_value in request.headers.raw:
             key = header_key.decode("ascii").upper().replace("-", "_")
             if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"):
                 key = "HTTP_" + key
@@ -141,6 +130,5 @@ class WSGITransport(BaseTransport):
             (key.encode("ascii"), value.encode("ascii"))
             for key, value in seen_response_headers
         ]
-        extensions = {}
 
-        return (status_code, headers, stream, extensions)
+        return Response(status_code, headers=headers, stream=stream)
diff --git a/httpx/_types.py b/httpx/_types.py
index 2381996c01bed3f98fa834dcdcb9ea71d0b10916..71a97a26360762f7ed270a67bcfe30aa97aab749 100644
@@ -8,9 +8,11 @@ from typing import (
     IO,
     TYPE_CHECKING,
     AsyncIterable,
+    AsyncIterator,
     Callable,
     Dict,
     Iterable,
+    Iterator,
     List,
     Mapping,
     Optional,
@@ -89,3 +91,60 @@ FileTypes = Union[
     Tuple[Optional[str], FileContent, Optional[str]],
 ]
 RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
+
+
+class SyncByteStream:
+    def __iter__(self) -> Iterator[bytes]:
+        raise NotImplementedError(
+            "The '__iter__' method must be implemented."
+        )  # pragma: nocover
+        yield b""  # pragma: nocover
+
+    def close(self) -> None:
+        """
+        Subclasses can override this method to release any network resources
+        after a request/response cycle is complete.
+
+        Streaming cases should use a `try...finally` block to ensure that
+        the stream `close()` method is always called.
+
+        Example:
+
+            status_code, headers, stream, extensions = transport.handle_request(...)
+            try:
+                ...
+            finally:
+                stream.close()
+        """
+
+    def read(self) -> bytes:
+        """
+        Simple cases can use `.read()` as a convience method for consuming
+        the entire stream and then closing it.
+
+        Example:
+
+            status_code, headers, stream, extensions = transport.handle_request(...)
+            body = stream.read()
+        """
+        try:
+            return b"".join([part for part in self])
+        finally:
+            self.close()
+
+
+class AsyncByteStream:
+    async def __aiter__(self) -> AsyncIterator[bytes]:
+        raise NotImplementedError(
+            "The '__aiter__' method must be implemented."
+        )  # pragma: nocover
+        yield b""  # pragma: nocover
+
+    async def aclose(self) -> None:
+        pass
+
+    async def aread(self) -> bytes:
+        try:
+            return b"".join([part async for part in self])
+        finally:
+            await self.aclose()
diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py
index b6cb42d0bb67a7e92062f736bf7e9d6ea5b6ba69..8caaeb5d498375cd0c1cb2bda780df3026a42d22 100644
@@ -619,6 +619,13 @@ def test_sync_auth_history() -> None:
     assert len(resp1.history) == 0
 
 
+class ConsumeBodyTransport(httpx.MockTransport):
+    async def handle_async_request(self, request: Request) -> Response:
+        assert isinstance(request.stream, httpx.AsyncByteStream)
+        [_ async for _ in request.stream]
+        return self.handler(request)
+
+
 @pytest.mark.asyncio
 async def test_digest_auth_unavailable_streaming_body():
     url = "https://example.org/"
@@ -628,7 +635,7 @@ async def test_digest_auth_unavailable_streaming_body():
     async def streaming_body():
         yield b"Example request body"  # pragma: nocover
 
-    async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
+    async with httpx.AsyncClient(transport=ConsumeBodyTransport(app)) as client:
         with pytest.raises(httpx.StreamConsumed):
             await client.post(url, content=streaming_body(), auth=auth)
 
diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py
index 87d9cdfa7eda1e4f22ccd42a80b5cb98011a1755..adc3aae38813cb9e5a22f6e43e33ed58d1662d70 100644
@@ -317,13 +317,20 @@ def test_can_stream_if_no_redirect():
     client = httpx.Client(transport=httpx.MockTransport(redirects))
     url = "https://example.org/redirect_301"
     with client.stream("GET", url, follow_redirects=False) as response:
-        assert not response.is_closed
+        pass
     assert response.status_code == httpx.codes.MOVED_PERMANENTLY
     assert response.headers["location"] == "https://example.org/"
 
 
+class ConsumeBodyTransport(httpx.MockTransport):
+    def handle_request(self, request: httpx.Request) -> httpx.Response:
+        assert isinstance(request.stream, httpx.SyncByteStream)
+        [_ for _ in request.stream]
+        return self.handler(request)
+
+
 def test_cannot_redirect_streaming_body():
-    client = httpx.Client(transport=httpx.MockTransport(redirects))
+    client = httpx.Client(transport=ConsumeBodyTransport(redirects))
     url = "https://example.org/redirect_body"
 
     def streaming_body():
diff --git a/tests/test_asgi.py b/tests/test_asgi.py
index d7cf9412af2ba6c30c6965f0b9ba888a8c8635a8..60f55dfd6fd7b9ddec624e014d905e45f6accb12 100644
@@ -70,40 +70,24 @@ async def raise_exc_after_response(scope, receive, send):
     raise RuntimeError()
 
 
-async def empty_stream():
-    yield b""
-
-
 @pytest.mark.usefixtures("async_environment")
 async def test_asgi_transport():
     async with httpx.ASGITransport(app=hello_world) as transport:
-        status_code, headers, stream, ext = await transport.handle_async_request(
-            method=b"GET",
-            url=(b"http", b"www.example.org", 80, b"/"),
-            headers=[(b"Host", b"www.example.org")],
-            stream=empty_stream(),
-            extensions={},
-        )
-        body = b"".join([part async for part in stream])
-
-        assert status_code == 200
-        assert body == b"Hello, World!"
+        request = httpx.Request("GET", "http://www.example.com/")
+        response = await transport.handle_async_request(request)
+        await response.aread()
+        assert response.status_code == 200
+        assert response.content == b"Hello, World!"
 
 
 @pytest.mark.usefixtures("async_environment")
 async def test_asgi_transport_no_body():
     async with httpx.ASGITransport(app=echo_body) as transport:
-        status_code, headers, stream, ext = await transport.handle_async_request(
-            method=b"GET",
-            url=(b"http", b"www.example.org", 80, b"/"),
-            headers=[(b"Host", b"www.example.org")],
-            stream=empty_stream(),
-            extensions={},
-        )
-        body = b"".join([part async for part in stream])
-
-        assert status_code == 200
-        assert body == b""
+        request = httpx.Request("GET", "http://www.example.com/")
+        response = await transport.handle_async_request(request)
+        await response.aread()
+        assert response.status_code == 200
+        assert response.content == b""
 
 
 @pytest.mark.usefixtures("async_environment")