From: Tom Christie Date: Mon, 13 Sep 2021 12:34:46 +0000 (+0100) Subject: Transport API as plain `request -> response` method. (#1840) X-Git-Tag: 1.0.0.beta0~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ff9813e84dab56f0f3c4ef3a159a4cce8c644a91;p=thirdparty%2Fhttpx.git Transport API as plain `request -> response` method. (#1840) * Responses as context managers * timeout -> request.extensions * Transport API -> request/response signature * Fix top-level httpx.stream() * Drop response context manager methods * Simplify ASGI tests * Black formatting --- diff --git a/httpx/__init__.py b/httpx/__init__.py index 4af3904f..bfce5763 100644 --- a/httpx/__init__.py +++ b/httpx/__init__.py @@ -37,15 +37,11 @@ from ._exceptions import ( from ._models import URL, Cookies, Headers, QueryParams, Request, Response from ._status_codes import codes from ._transports.asgi import ASGITransport -from ._transports.base import ( - AsyncBaseTransport, - AsyncByteStream, - BaseTransport, - SyncByteStream, -) +from ._transports.base import AsyncBaseTransport, BaseTransport from ._transports.default import AsyncHTTPTransport, HTTPTransport from ._transports.mock import MockTransport from ._transports.wsgi import WSGITransport +from ._types import AsyncByteStream, SyncByteStream __all__ = [ "__description__", diff --git a/httpx/_client.py b/httpx/_client.py index 6e8bb2f3..7492cb45 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -26,15 +26,11 @@ from ._exceptions import ( from ._models import URL, Cookies, Headers, QueryParams, Request, Response from ._status_codes import codes from ._transports.asgi import ASGITransport -from ._transports.base import ( - AsyncBaseTransport, - AsyncByteStream, - BaseTransport, - SyncByteStream, -) +from ._transports.base import AsyncBaseTransport, BaseTransport from ._transports.default import AsyncHTTPTransport, HTTPTransport from ._transports.wsgi import WSGITransport from ._types import ( + AsyncByteStream, AuthTypes, CertTypes, CookieTypes, @@ -44,6 +40,7 @@ from ._types import ( RequestContent, RequestData, RequestFiles, + SyncByteStream, TimeoutTypes, URLTypes, VerifyTypes, @@ -327,6 +324,7 @@ class BaseClient: params: QueryParamTypes = None, headers: HeaderTypes = None, cookies: CookieTypes = None, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, ) -> Request: """ Build and return a request instance. @@ -343,6 +341,9 @@ class BaseClient: headers = self._merge_headers(headers) cookies = self._merge_cookies(cookies) params = self._merge_queryparams(params) + timeout = ( + self.timeout if isinstance(timeout, UseClientDefault) else Timeout(timeout) + ) return Request( method, url, @@ -353,6 +354,7 @@ class BaseClient: params=params, headers=headers, cookies=cookies, + extensions={"timeout": timeout.as_dict()}, ) def _merge_url(self, url: URLTypes) -> URL: @@ -785,10 +787,9 @@ class Client(BaseClient): params=params, headers=headers, cookies=cookies, + timeout=timeout, ) - return self.send( - request, auth=auth, follow_redirects=follow_redirects, timeout=timeout - ) + return self.send(request, auth=auth, follow_redirects=follow_redirects) @contextmanager def stream( @@ -827,12 +828,12 @@ class Client(BaseClient): params=params, headers=headers, cookies=cookies, + timeout=timeout, ) response = self.send( request=request, auth=auth, follow_redirects=follow_redirects, - timeout=timeout, stream=True, ) try: @@ -847,7 +848,6 @@ class Client(BaseClient): stream: bool = False, auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, - timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, ) -> Response: """ Send a request. @@ -866,9 +866,6 @@ class Client(BaseClient): raise RuntimeError("Cannot send a request, as the client has been closed.") self._state = ClientState.OPENED - timeout = ( - self.timeout if isinstance(timeout, UseClientDefault) else Timeout(timeout) - ) follow_redirects = ( self.follow_redirects if isinstance(follow_redirects, UseClientDefault) @@ -880,7 +877,6 @@ class Client(BaseClient): response = self._send_handling_auth( request, auth=auth, - timeout=timeout, follow_redirects=follow_redirects, history=[], ) @@ -898,7 +894,6 @@ class Client(BaseClient): self, request: Request, auth: Auth, - timeout: Timeout, follow_redirects: bool, history: typing.List[Response], ) -> Response: @@ -909,7 +904,6 @@ class Client(BaseClient): while True: response = self._send_handling_redirects( request, - timeout=timeout, follow_redirects=follow_redirects, history=history, ) @@ -933,7 +927,6 @@ class Client(BaseClient): def _send_handling_redirects( self, request: Request, - timeout: Timeout, follow_redirects: bool, history: typing.List[Response], ) -> Response: @@ -946,7 +939,7 @@ class Client(BaseClient): for hook in self._event_hooks["request"]: hook(request) - response = self._send_single_request(request, timeout) + response = self._send_single_request(request) try: for hook in self._event_hooks["response"]: hook(response) @@ -968,7 +961,7 @@ class Client(BaseClient): response.close() raise exc - def _send_single_request(self, request: Request, timeout: Timeout) -> Response: + def _send_single_request(self, request: Request) -> Response: """ Sends a single request, without handling any redirections. """ @@ -982,23 +975,14 @@ class Client(BaseClient): ) with request_context(request=request): - (status_code, headers, stream, extensions) = transport.handle_request( - request.method.encode(), - request.url.raw, - headers=request.headers.raw, - stream=request.stream, - extensions={"timeout": timeout.as_dict()}, - ) + response = transport.handle_request(request) - response = Response( - status_code, - headers=headers, - stream=stream, - extensions=extensions, - request=request, - ) + assert isinstance(response.stream, SyncByteStream) - response.stream = BoundSyncStream(stream, response=response, timer=timer) + response.request = request + response.stream = BoundSyncStream( + response.stream, response=response, timer=timer + ) self.cookies.extract_cookies(response) status = f"{response.status_code} {response.reason_phrase}" @@ -1494,9 +1478,10 @@ class AsyncClient(BaseClient): params=params, headers=headers, cookies=cookies, + timeout=timeout, ) response = await self.send( - request, auth=auth, follow_redirects=follow_redirects, timeout=timeout + request, auth=auth, follow_redirects=follow_redirects ) return response @@ -1537,12 +1522,12 @@ class AsyncClient(BaseClient): params=params, headers=headers, cookies=cookies, + timeout=timeout, ) response = await self.send( request=request, auth=auth, follow_redirects=follow_redirects, - timeout=timeout, stream=True, ) try: @@ -1557,7 +1542,6 @@ class AsyncClient(BaseClient): stream: bool = False, auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, - timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, ) -> Response: """ Send a request. @@ -1576,9 +1560,6 @@ class AsyncClient(BaseClient): raise RuntimeError("Cannot send a request, as the client has been closed.") self._state = ClientState.OPENED - timeout = ( - self.timeout if isinstance(timeout, UseClientDefault) else Timeout(timeout) - ) follow_redirects = ( self.follow_redirects if isinstance(follow_redirects, UseClientDefault) @@ -1590,7 +1571,6 @@ class AsyncClient(BaseClient): response = await self._send_handling_auth( request, auth=auth, - timeout=timeout, follow_redirects=follow_redirects, history=[], ) @@ -1608,7 +1588,6 @@ class AsyncClient(BaseClient): self, request: Request, auth: Auth, - timeout: Timeout, follow_redirects: bool, history: typing.List[Response], ) -> Response: @@ -1619,7 +1598,6 @@ class AsyncClient(BaseClient): while True: response = await self._send_handling_redirects( request, - timeout=timeout, follow_redirects=follow_redirects, history=history, ) @@ -1643,7 +1621,6 @@ class AsyncClient(BaseClient): async def _send_handling_redirects( self, request: Request, - timeout: Timeout, follow_redirects: bool, history: typing.List[Response], ) -> Response: @@ -1656,7 +1633,7 @@ class AsyncClient(BaseClient): for hook in self._event_hooks["request"]: await hook(request) - response = await self._send_single_request(request, timeout) + response = await self._send_single_request(request) try: for hook in self._event_hooks["response"]: await hook(response) @@ -1679,9 +1656,7 @@ class AsyncClient(BaseClient): await response.aclose() raise exc - async def _send_single_request( - self, request: Request, timeout: Timeout - ) -> Response: + async def _send_single_request(self, request: Request) -> Response: """ Sends a single request, without handling any redirections. """ @@ -1695,28 +1670,13 @@ class AsyncClient(BaseClient): ) with request_context(request=request): - ( - status_code, - headers, - stream, - extensions, - ) = await transport.handle_async_request( - request.method.encode(), - request.url.raw, - headers=request.headers.raw, - stream=request.stream, - extensions={"timeout": timeout.as_dict()}, - ) + response = await transport.handle_async_request(request) - response = Response( - status_code, - headers=headers, - stream=stream, - extensions=extensions, - request=request, + assert isinstance(response.stream, AsyncByteStream) + response.request = request + response.stream = BoundAsyncStream( + response.stream, response=response, timer=timer ) - - response.stream = BoundAsyncStream(stream, response=response, timer=timer) self.cookies.extract_cookies(response) status = f"{response.status_code} {response.reason_phrase}" diff --git a/httpx/_content.py b/httpx/_content.py index 86f3c7c2..d7e8aa09 100644 --- a/httpx/_content.py +++ b/httpx/_content.py @@ -15,8 +15,14 @@ from urllib.parse import urlencode from ._exceptions import StreamClosed, StreamConsumed from ._multipart import MultipartStream -from ._transports.base import AsyncByteStream, SyncByteStream -from ._types import RequestContent, RequestData, RequestFiles, ResponseContent +from ._types import ( + AsyncByteStream, + RequestContent, + RequestData, + RequestFiles, + ResponseContent, + SyncByteStream, +) from ._utils import peek_filelike_length, primitive_value_to_str diff --git a/httpx/_models.py b/httpx/_models.py index 0a54a6fa..7c6460e7 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -35,8 +35,8 @@ from ._exceptions import ( request_context, ) from ._status_codes import codes -from ._transports.base import AsyncByteStream, SyncByteStream from ._types import ( + AsyncByteStream, CookieTypes, HeaderTypes, PrimitiveData, @@ -46,6 +46,7 @@ from ._types import ( RequestData, RequestFiles, ResponseContent, + SyncByteStream, URLTypes, ) from ._utils import ( @@ -1081,15 +1082,19 @@ class Request: files: RequestFiles = None, json: typing.Any = None, stream: typing.Union[SyncByteStream, AsyncByteStream] = None, + extensions: dict = None, ): - if isinstance(method, bytes): - self.method = method.decode("ascii").upper() - else: - self.method = method.upper() + self.method = ( + method.decode("ascii").upper() + if isinstance(method, bytes) + else method.upper() + ) self.url = URL(url) if params is not None: self.url = self.url.copy_merge_params(params=params) self.headers = Headers(headers) + self.extensions = {} if extensions is None else extensions + if cookies: Cookies(cookies).set_cookie_header(self) diff --git a/httpx/_multipart.py b/httpx/_multipart.py index 683e6f13..4dfb838a 100644 --- a/httpx/_multipart.py +++ b/httpx/_multipart.py @@ -4,8 +4,13 @@ import os import typing from pathlib import Path -from ._transports.base import AsyncByteStream, SyncByteStream -from ._types import FileContent, FileTypes, RequestFiles +from ._types import ( + AsyncByteStream, + FileContent, + FileTypes, + RequestFiles, + SyncByteStream, +) from ._utils import ( format_form_param, guess_content_type, diff --git a/httpx/_transports/asgi.py b/httpx/_transports/asgi.py index 24c5452d..4e361658 100644 --- a/httpx/_transports/asgi.py +++ b/httpx/_transports/asgi.py @@ -1,9 +1,10 @@ import typing -from urllib.parse import unquote import sniffio -from .base import AsyncBaseTransport, AsyncByteStream +from .._models import Request, Response +from .._types import AsyncByteStream +from .base import AsyncBaseTransport if typing.TYPE_CHECKING: # pragma: no cover import asyncio @@ -79,34 +80,28 @@ class ASGITransport(AsyncBaseTransport): async def handle_async_request( self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]], - stream: AsyncByteStream, - extensions: dict, - ) -> typing.Tuple[ - int, typing.List[typing.Tuple[bytes, bytes]], AsyncByteStream, dict - ]: + request: Request, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + # ASGI scope. - scheme, host, port, full_path = url - path, _, query = full_path.partition(b"?") scope = { "type": "http", "asgi": {"version": "3.0"}, "http_version": "1.1", - "method": method.decode(), - "headers": [(k.lower(), v) for (k, v) in headers], - "scheme": scheme.decode("ascii"), - "path": unquote(path.decode("ascii")), - "raw_path": path, - "query_string": query, - "server": (host.decode("ascii"), port), + "method": request.method, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path, + "query_string": request.url.query, + "server": (request.url.host, request.url.port), "client": self.client, "root_path": self.root_path, } # Request. - request_body_chunks = stream.__aiter__() + request_body_chunks = request.stream.__aiter__() request_complete = False # Response. @@ -147,7 +142,7 @@ class ASGITransport(AsyncBaseTransport): body = message.get("body", b"") more_body = message.get("more_body", False) - if body and method != b"HEAD": + if body and request.method != "HEAD": body_parts.append(body) if not more_body: @@ -164,6 +159,5 @@ class ASGITransport(AsyncBaseTransport): assert response_headers is not None stream = ASGIResponseStream(body_parts) - extensions = {} - return (status_code, response_headers, stream, extensions) + return Response(status_code, headers=response_headers, stream=stream) diff --git a/httpx/_transports/base.py b/httpx/_transports/base.py index eb519269..8c324ab4 100644 --- a/httpx/_transports/base.py +++ b/httpx/_transports/base.py @@ -1,67 +1,12 @@ import typing from types import TracebackType +from .._models import Request, Response + T = typing.TypeVar("T", bound="BaseTransport") A = typing.TypeVar("A", bound="AsyncBaseTransport") -class SyncByteStream: - def __iter__(self) -> typing.Iterator[bytes]: - raise NotImplementedError( - "The '__iter__' method must be implemented." - ) # pragma: nocover - yield b"" # pragma: nocover - - def close(self) -> None: - """ - Subclasses can override this method to release any network resources - after a request/response cycle is complete. - - Streaming cases should use a `try...finally` block to ensure that - the stream `close()` method is always called. - - Example: - - status_code, headers, stream, extensions = transport.handle_request(...) - try: - ... - finally: - stream.close() - """ - - def read(self) -> bytes: - """ - Simple cases can use `.read()` as a convience method for consuming - the entire stream and then closing it. - - Example: - - status_code, headers, stream, extensions = transport.handle_request(...) - body = stream.read() - """ - try: - return b"".join([part for part in self]) - finally: - self.close() - - -class AsyncByteStream: - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - raise NotImplementedError( - "The '__aiter__' method must be implemented." - ) # pragma: nocover - yield b"" # pragma: nocover - - async def aclose(self) -> None: - pass - - async def aread(self) -> bytes: - try: - return b"".join([part async for part in self]) - finally: - await self.aclose() - - class BaseTransport: def __enter__(self: T) -> T: return self @@ -74,16 +19,7 @@ class BaseTransport: ) -> None: self.close() - def handle_request( - self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]], - stream: SyncByteStream, - extensions: dict, - ) -> typing.Tuple[ - int, typing.List[typing.Tuple[bytes, bytes]], SyncByteStream, dict - ]: + def handle_request(self, request: Request) -> Response: """ Send a single HTTP request and return a response. @@ -167,14 +103,8 @@ class AsyncBaseTransport: async def handle_async_request( self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]], - stream: AsyncByteStream, - extensions: dict, - ) -> typing.Tuple[ - int, typing.List[typing.Tuple[bytes, bytes]], AsyncByteStream, dict - ]: + request: Request, + ) -> Response: raise NotImplementedError( "The 'handle_async_request' method must be implemented." ) # pragma: nocover diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index 73401fce..2566a3f2 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -48,8 +48,9 @@ from .._exceptions import ( WriteError, WriteTimeout, ) -from .._types import CertTypes, VerifyTypes -from .base import AsyncBaseTransport, AsyncByteStream, BaseTransport, SyncByteStream +from .._models import Request, Response +from .._types import AsyncByteStream, CertTypes, SyncByteStream, VerifyTypes +from .base import AsyncBaseTransport, BaseTransport T = typing.TypeVar("T", bound="HTTPTransport") A = typing.TypeVar("A", bound="AsyncHTTPTransport") @@ -168,26 +169,24 @@ class HTTPTransport(BaseTransport): def handle_request( self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]], - stream: SyncByteStream, - extensions: dict, - ) -> typing.Tuple[ - int, typing.List[typing.Tuple[bytes, bytes]], SyncByteStream, dict - ]: + request: Request, + ) -> Response: + assert isinstance(request.stream, SyncByteStream) + with map_httpcore_exceptions(): status_code, headers, byte_stream, extensions = self._pool.handle_request( - method=method, - url=url, - headers=headers, - stream=httpcore.IteratorByteStream(iter(stream)), - extensions=extensions, + method=request.method.encode("ascii"), + url=request.url.raw, + headers=request.headers.raw, + stream=httpcore.IteratorByteStream(iter(request.stream)), + extensions=request.extensions, ) stream = ResponseStream(byte_stream) - return status_code, headers, stream, extensions + return Response( + status_code, headers=headers, stream=stream, extensions=extensions + ) def close(self) -> None: self._pool.close() @@ -264,14 +263,10 @@ class AsyncHTTPTransport(AsyncBaseTransport): async def handle_async_request( self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]], - stream: AsyncByteStream, - extensions: dict, - ) -> typing.Tuple[ - int, typing.List[typing.Tuple[bytes, bytes]], AsyncByteStream, dict - ]: + request: Request, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + with map_httpcore_exceptions(): ( status_code, @@ -279,16 +274,18 @@ class AsyncHTTPTransport(AsyncBaseTransport): byte_stream, extensions, ) = await self._pool.handle_async_request( - method=method, - url=url, - headers=headers, - stream=httpcore.AsyncIteratorByteStream(stream.__aiter__()), - extensions=extensions, + method=request.method.encode("ascii"), + url=request.url.raw, + headers=request.headers.raw, + stream=httpcore.AsyncIteratorByteStream(request.stream.__aiter__()), + extensions=request.extensions, ) stream = AsyncResponseStream(byte_stream) - return status_code, headers, stream, extensions + return Response( + status_code, headers=headers, stream=stream, extensions=extensions + ) async def aclose(self) -> None: await self._pool.aclose() diff --git a/httpx/_transports/mock.py b/httpx/_transports/mock.py index 8d59b738..f61aee71 100644 --- a/httpx/_transports/mock.py +++ b/httpx/_transports/mock.py @@ -1,8 +1,8 @@ import asyncio import typing -from .._models import Request -from .base import AsyncBaseTransport, AsyncByteStream, BaseTransport, SyncByteStream +from .._models import Request, Response +from .base import AsyncBaseTransport, BaseTransport class MockTransport(AsyncBaseTransport, BaseTransport): @@ -11,47 +11,16 @@ class MockTransport(AsyncBaseTransport, BaseTransport): def handle_request( self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]], - stream: SyncByteStream, - extensions: dict, - ) -> typing.Tuple[ - int, typing.List[typing.Tuple[bytes, bytes]], SyncByteStream, dict - ]: - request = Request( - method=method, - url=url, - headers=headers, - stream=stream, - ) + request: Request, + ) -> Response: request.read() - response = self.handler(request) - return ( - response.status_code, - response.headers.raw, - response.stream, - response.extensions, - ) + return self.handler(request) async def handle_async_request( self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]], - stream: AsyncByteStream, - extensions: dict, - ) -> typing.Tuple[ - int, typing.List[typing.Tuple[bytes, bytes]], AsyncByteStream, dict - ]: - request = Request( - method=method, - url=url, - headers=headers, - stream=stream, - ) + request: Request, + ) -> Response: await request.aread() - response = self.handler(request) # Allow handler to *optionally* be an `async` function. @@ -62,9 +31,4 @@ class MockTransport(AsyncBaseTransport, BaseTransport): if asyncio.iscoroutine(response): response = await response - return ( - response.status_code, - response.headers.raw, - response.stream, - response.extensions, - ) + return response diff --git a/httpx/_transports/wsgi.py b/httpx/_transports/wsgi.py index e8bdfd3f..3dedf49f 100644 --- a/httpx/_transports/wsgi.py +++ b/httpx/_transports/wsgi.py @@ -2,9 +2,10 @@ import io import itertools import sys import typing -from urllib.parse import unquote -from .base import BaseTransport, SyncByteStream +from .._models import Request, Response +from .._types import SyncByteStream +from .base import BaseTransport def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable: @@ -76,40 +77,28 @@ class WSGITransport(BaseTransport): self.remote_addr = remote_addr self.wsgi_errors = wsgi_errors - def handle_request( - self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]], - stream: SyncByteStream, - extensions: dict, - ) -> typing.Tuple[ - int, typing.List[typing.Tuple[bytes, bytes]], SyncByteStream, dict - ]: - wsgi_input = io.BytesIO(b"".join(stream)) - - scheme, host, port, full_path = url - path, _, query = full_path.partition(b"?") - if port is None: - port = {b"http": 80, b"https": 443}[scheme] + def handle_request(self, request: Request) -> Response: + request.read() + wsgi_input = io.BytesIO(request.content) + port = request.url.port or {"http": 80, "https": 443}[request.url.scheme] environ = { "wsgi.version": (1, 0), - "wsgi.url_scheme": scheme.decode("ascii"), + "wsgi.url_scheme": request.url.scheme, "wsgi.input": wsgi_input, "wsgi.errors": self.wsgi_errors or sys.stderr, "wsgi.multithread": True, "wsgi.multiprocess": False, "wsgi.run_once": False, - "REQUEST_METHOD": method.decode(), + "REQUEST_METHOD": request.method, "SCRIPT_NAME": self.script_name, - "PATH_INFO": unquote(path.decode("ascii")), - "QUERY_STRING": query.decode("ascii"), - "SERVER_NAME": host.decode("ascii"), + "PATH_INFO": request.url.path, + "QUERY_STRING": request.url.query.decode("ascii"), + "SERVER_NAME": request.url.host, "SERVER_PORT": str(port), "REMOTE_ADDR": self.remote_addr, } - for header_key, header_value in headers: + for header_key, header_value in request.headers.raw: key = header_key.decode("ascii").upper().replace("-", "_") if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"): key = "HTTP_" + key @@ -141,6 +130,5 @@ class WSGITransport(BaseTransport): (key.encode("ascii"), value.encode("ascii")) for key, value in seen_response_headers ] - extensions = {} - return (status_code, headers, stream, extensions) + return Response(status_code, headers=headers, stream=stream) diff --git a/httpx/_types.py b/httpx/_types.py index 2381996c..71a97a26 100644 --- a/httpx/_types.py +++ b/httpx/_types.py @@ -8,9 +8,11 @@ from typing import ( IO, TYPE_CHECKING, AsyncIterable, + AsyncIterator, Callable, Dict, Iterable, + Iterator, List, Mapping, Optional, @@ -89,3 +91,60 @@ FileTypes = Union[ Tuple[Optional[str], FileContent, Optional[str]], ] RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] + + +class SyncByteStream: + def __iter__(self) -> Iterator[bytes]: + raise NotImplementedError( + "The '__iter__' method must be implemented." + ) # pragma: nocover + yield b"" # pragma: nocover + + def close(self) -> None: + """ + Subclasses can override this method to release any network resources + after a request/response cycle is complete. + + Streaming cases should use a `try...finally` block to ensure that + the stream `close()` method is always called. + + Example: + + status_code, headers, stream, extensions = transport.handle_request(...) + try: + ... + finally: + stream.close() + """ + + def read(self) -> bytes: + """ + Simple cases can use `.read()` as a convience method for consuming + the entire stream and then closing it. + + Example: + + status_code, headers, stream, extensions = transport.handle_request(...) + body = stream.read() + """ + try: + return b"".join([part for part in self]) + finally: + self.close() + + +class AsyncByteStream: + async def __aiter__(self) -> AsyncIterator[bytes]: + raise NotImplementedError( + "The '__aiter__' method must be implemented." + ) # pragma: nocover + yield b"" # pragma: nocover + + async def aclose(self) -> None: + pass + + async def aread(self) -> bytes: + try: + return b"".join([part async for part in self]) + finally: + await self.aclose() diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index b6cb42d0..8caaeb5d 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -619,6 +619,13 @@ def test_sync_auth_history() -> None: assert len(resp1.history) == 0 +class ConsumeBodyTransport(httpx.MockTransport): + async def handle_async_request(self, request: Request) -> Response: + assert isinstance(request.stream, httpx.AsyncByteStream) + [_ async for _ in request.stream] + return self.handler(request) + + @pytest.mark.asyncio async def test_digest_auth_unavailable_streaming_body(): url = "https://example.org/" @@ -628,7 +635,7 @@ async def test_digest_auth_unavailable_streaming_body(): async def streaming_body(): yield b"Example request body" # pragma: nocover - async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client: + async with httpx.AsyncClient(transport=ConsumeBodyTransport(app)) as client: with pytest.raises(httpx.StreamConsumed): await client.post(url, content=streaming_body(), auth=auth) diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index 87d9cdfa..adc3aae3 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -317,13 +317,20 @@ def test_can_stream_if_no_redirect(): client = httpx.Client(transport=httpx.MockTransport(redirects)) url = "https://example.org/redirect_301" with client.stream("GET", url, follow_redirects=False) as response: - assert not response.is_closed + pass assert response.status_code == httpx.codes.MOVED_PERMANENTLY assert response.headers["location"] == "https://example.org/" +class ConsumeBodyTransport(httpx.MockTransport): + def handle_request(self, request: httpx.Request) -> httpx.Response: + assert isinstance(request.stream, httpx.SyncByteStream) + [_ for _ in request.stream] + return self.handler(request) + + def test_cannot_redirect_streaming_body(): - client = httpx.Client(transport=httpx.MockTransport(redirects)) + client = httpx.Client(transport=ConsumeBodyTransport(redirects)) url = "https://example.org/redirect_body" def streaming_body(): diff --git a/tests/test_asgi.py b/tests/test_asgi.py index d7cf9412..60f55dfd 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -70,40 +70,24 @@ async def raise_exc_after_response(scope, receive, send): raise RuntimeError() -async def empty_stream(): - yield b"" - - @pytest.mark.usefixtures("async_environment") async def test_asgi_transport(): async with httpx.ASGITransport(app=hello_world) as transport: - status_code, headers, stream, ext = await transport.handle_async_request( - method=b"GET", - url=(b"http", b"www.example.org", 80, b"/"), - headers=[(b"Host", b"www.example.org")], - stream=empty_stream(), - extensions={}, - ) - body = b"".join([part async for part in stream]) - - assert status_code == 200 - assert body == b"Hello, World!" + request = httpx.Request("GET", "http://www.example.com/") + response = await transport.handle_async_request(request) + await response.aread() + assert response.status_code == 200 + assert response.content == b"Hello, World!" @pytest.mark.usefixtures("async_environment") async def test_asgi_transport_no_body(): async with httpx.ASGITransport(app=echo_body) as transport: - status_code, headers, stream, ext = await transport.handle_async_request( - method=b"GET", - url=(b"http", b"www.example.org", 80, b"/"), - headers=[(b"Host", b"www.example.org")], - stream=empty_stream(), - extensions={}, - ) - body = b"".join([part async for part in stream]) - - assert status_code == 200 - assert body == b"" + request = httpx.Request("GET", "http://www.example.com/") + response = await transport.handle_async_request(request) + await response.aread() + assert response.status_code == 200 + assert response.content == b"" @pytest.mark.usefixtures("async_environment")