]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Replace HTTP client on TestClient from `requests` to `httpx` (#1376)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Tue, 6 Sep 2022 05:43:32 +0000 (07:43 +0200)
committerGitHub <noreply@github.com>
Tue, 6 Sep 2022 05:43:32 +0000 (07:43 +0200)
13 files changed:
README.md
docs/index.md
docs/testclient.md
pyproject.toml
requirements.txt
setup.cfg
starlette/testclient.py
tests/middleware/test_cors.py
tests/middleware/test_session.py
tests/test_formparsers.py
tests/test_requests.py
tests/test_responses.py
tests/test_staticfiles.py

index 44bd55c77e82a6d833c11b42e4645fb1739532e2..50e0dd63badc7769d42e6eadc1db942ceb667d5f 100644 (file)
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ It is production-ready, and gives you the following:
 * WebSocket support.
 * In-process background tasks.
 * Startup and shutdown events.
-* Test client built on `requests`.
+* Test client built on `httpx`.
 * CORS, GZip, Static Files, Streaming responses.
 * Session and Cookie support.
 * 100% test coverage.
@@ -87,7 +87,7 @@ For a more complete example, see [encode/starlette-example](https://github.com/e
 
 Starlette only requires `anyio`, and the following are optional:
 
-* [`requests`][requests] - Required if you want to use the `TestClient`.
+* [`httpx`][httpx] - Required if you want to use the `TestClient`.
 * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
 * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
 * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
@@ -134,7 +134,7 @@ in isolation.
 <p align="center"><i>Starlette is <a href="https://github.com/encode/starlette/blob/master/LICENSE.md">BSD licensed</a> code.<br/>Designed & crafted with care.</i></br>&mdash; ⭐️ &mdash;</p>
 
 [asgi]: https://asgi.readthedocs.io/en/latest/
-[requests]: http://docs.python-requests.org/en/master/
+[httpx]: https://www.python-httpx.org/
 [jinja2]: http://jinja.pocoo.org/
 [python-multipart]: https://andrew-d.github.io/python-multipart/
 [itsdangerous]: https://pythonhosted.org/itsdangerous/
index 5a501402133ae4822e273b0e5d89f81be543711a..9f19778751462dfc8092d767cccef8e3cf4441c5 100644 (file)
@@ -27,7 +27,7 @@ It is production-ready, and gives you the following:
 * WebSocket support.
 * In-process background tasks.
 * Startup and shutdown events.
-* Test client built on `requests`.
+* Test client built on `httpx`.
 * CORS, GZip, Static Files, Streaming responses.
 * Session and Cookie support.
 * 100% test coverage.
@@ -83,7 +83,7 @@ For a more complete example, [see here](https://github.com/encode/starlette-exam
 
 Starlette only requires `anyio`, and the following dependencies are optional:
 
-* [`requests`][requests] - Required if you want to use the `TestClient`.
+* [`httpx`][httpx] - Required if you want to use the `TestClient`.
 * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
 * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
 * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
@@ -130,7 +130,7 @@ in isolation.
 <p align="center"><i>Starlette is <a href="https://github.com/encode/starlette/blob/master/LICENSE.md">BSD licensed</a> code.<br/>Designed & crafted with care.</i></br>&mdash; ⭐️ &mdash;</p>
 
 [asgi]: https://asgi.readthedocs.io/en/latest/
-[requests]: http://docs.python-requests.org/en/master/
+[httpx]: https://www.python-httpx.org/
 [jinja2]: http://jinja.pocoo.org/
 [python-multipart]: https://andrew-d.github.io/python-multipart/
 [itsdangerous]: https://pythonhosted.org/itsdangerous/
index f64c570bb7b117a9eecb78b9ac934122a8143936..053b420055f6ce0eaa43b132a640c4e38cbf791b 100644 (file)
@@ -1,6 +1,6 @@
 
 The test client allows you to make requests against your ASGI application,
-using the `requests` library.
+using the `httpx` library.
 
 ```python
 from starlette.responses import HTMLResponse
@@ -19,11 +19,11 @@ def test_app():
     assert response.status_code == 200
 ```
 
-The test client exposes the same interface as any other `requests` session.
+The test client exposes the same interface as any other `httpx` session.
 In particular, note that the calls to make a request are just standard
 function calls, not awaitables.
 
-You can use any of `requests` standard API, such as authentication, session
+You can use any of `httpx` standard API, such as authentication, session
 cookies handling, or file uploads.
 
 For example, to set headers on the TestClient you can do:
@@ -96,7 +96,7 @@ def test_app()
 
 You can also test websocket sessions with the test client.
 
-The `requests` library will be used to build the initial handshake, meaning you
+The `httpx` library will be used to build the initial handshake, meaning you
 can use the same authentication options and other headers between both http and
 websocket testing.
 
@@ -129,7 +129,7 @@ always raised by the test client.
 
 #### Establishing a test session
 
-* `.websocket_connect(url, subprotocols=None, **options)` - Takes the same set of arguments as `requests.get()`.
+* `.websocket_connect(url, subprotocols=None, **options)` - Takes the same set of arguments as `httpx.get()`.
 
 May raise `starlette.websockets.WebSocketDisconnect` if the application does not accept the websocket connection.
 
index 7bbce89d939261d9de3048ecfac3387471451785..f994fc361367b18f4a85b2d47a03c05bd4b7d3f7 100644 (file)
@@ -37,7 +37,7 @@ full = [
     "jinja2",
     "python-multipart",
     "pyyaml",
-    "requests",
+    "httpx>=0.22.0",
 ]
 
 [project.urls]
index 648f0fa0184adef6230a29ede72a6dd94065b270..0b54fa596f40944cfbc46a47d6b7866b033c5356 100644 (file)
@@ -10,7 +10,6 @@ flake8==3.9.2
 isort==5.10.1
 mypy==0.971
 typing_extensions==4.3.0
-types-requests==2.26.3
 types-contextvars==2.4.7
 types-PyYAML==6.0.11
 types-dataclasses==0.6.6
index 93f27e4e08c33db147df57f3ffc2bc2ab2804e42..23cf32cc0318e3b8e87be9cf508928dfdefd0502 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -32,6 +32,9 @@ filterwarnings=
     ignore: starlette\.middleware\.wsgi is deprecated and will be removed in a future release\.*:DeprecationWarning
     ignore: Async generator 'starlette\.requests\.Request\.stream' was garbage collected before it had been exhausted.*:ResourceWarning
     ignore: path is deprecated.*:DeprecationWarning:certifi
+    ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning
+    ignore: The `allow_redirects` argument is deprecated. Use `follow_redirects` instead.:DeprecationWarning
+    ignore: 'cgi' is deprecated and slated for removal in Python 3\.13:DeprecationWarning
 
 [coverage:run]
 source_pkgs = starlette, tests
index efe2b493bb0dd42f3a3262a5eb1fc3d450a5fb51..455440ce596d21d28ef3e9bec3c3d1534b878586 100644 (file)
@@ -1,22 +1,22 @@
 import contextlib
-import http
 import inspect
 import io
 import json
 import math
 import queue
 import sys
-import types
 import typing
+import warnings
 from concurrent.futures import Future
-from urllib.parse import unquote, urljoin, urlsplit
+from types import GeneratorType
+from urllib.parse import unquote, urljoin
 
-import anyio.abc
-import requests
+import anyio
+import httpx
 from anyio.streams.stapled import StapledObjectStream
 
 from starlette._utils import is_async_callable
-from starlette.types import Message, Receive, Scope, Send
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
 from starlette.websockets import WebSocketDisconnect
 
 if sys.version_info >= (3, 8):  # pragma: no cover
@@ -24,63 +24,15 @@ if sys.version_info >= (3, 8):  # pragma: no cover
 else:  # pragma: no cover
     from typing_extensions import TypedDict
 
-
 _PortalFactoryType = typing.Callable[
     [], typing.ContextManager[anyio.abc.BlockingPortal]
 ]
 
-
-# Annotations for `Session.request()`
-Cookies = typing.Union[
-    typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
-]
-Params = typing.Union[bytes, typing.MutableMapping[str, str]]
-DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO]
-TimeOut = typing.Union[float, typing.Tuple[float, float]]
-FileType = typing.MutableMapping[str, typing.IO]
-AuthType = typing.Union[
-    typing.Tuple[str, str],
-    requests.auth.AuthBase,
-    typing.Callable[[requests.PreparedRequest], requests.PreparedRequest],
-]
-
-
 ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
 ASGI2App = typing.Callable[[Scope], ASGIInstance]
 ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
 
 
-class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
-    def get_all(self, key: str, default: str) -> str:
-        return self.getheaders(key)
-
-
-class _MockOriginalResponse:
-    """
-    We have to jump through some hoops to present the response as if
-    it was made using urllib3.
-    """
-
-    def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
-        self.msg = _HeaderDict(headers)
-        self.closed = False
-
-    def isclosed(self) -> bool:
-        return self.closed
-
-
-class _Upgrade(Exception):
-    def __init__(self, session: "WebSocketTestSession") -> None:
-        self.session = session
-
-
-def _get_reason_phrase(status_code: int) -> str:
-    try:
-        return http.HTTPStatus(status_code).phrase
-    except ValueError:
-        return ""
-
-
 def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
     if inspect.isclass(app):
         return hasattr(app, "__await__")
@@ -105,7 +57,127 @@ class _AsyncBackend(TypedDict):
     backend_options: typing.Dict[str, typing.Any]
 
 
-class _ASGIAdapter(requests.adapters.HTTPAdapter):
+class _Upgrade(Exception):
+    def __init__(self, session: "WebSocketTestSession") -> None:
+        self.session = session
+
+
+class WebSocketTestSession:
+    def __init__(
+        self,
+        app: ASGI3App,
+        scope: Scope,
+        portal_factory: _PortalFactoryType,
+    ) -> None:
+        self.app = app
+        self.scope = scope
+        self.accepted_subprotocol = None
+        self.portal_factory = portal_factory
+        self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
+        self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
+        self.extra_headers = None
+
+    def __enter__(self) -> "WebSocketTestSession":
+        self.exit_stack = contextlib.ExitStack()
+        self.portal = self.exit_stack.enter_context(self.portal_factory())
+
+        try:
+            _: "Future[None]" = self.portal.start_task_soon(self._run)
+            self.send({"type": "websocket.connect"})
+            message = self.receive()
+            self._raise_on_close(message)
+        except Exception:
+            self.exit_stack.close()
+            raise
+        self.accepted_subprotocol = message.get("subprotocol", None)
+        self.extra_headers = message.get("headers", None)
+        return self
+
+    def __exit__(self, *args: typing.Any) -> None:
+        try:
+            self.close(1000)
+        finally:
+            self.exit_stack.close()
+        while not self._send_queue.empty():
+            message = self._send_queue.get()
+            if isinstance(message, BaseException):
+                raise message
+
+    async def _run(self) -> None:
+        """
+        The sub-thread in which the websocket session runs.
+        """
+        scope = self.scope
+        receive = self._asgi_receive
+        send = self._asgi_send
+        try:
+            await self.app(scope, receive, send)
+        except BaseException as exc:
+            self._send_queue.put(exc)
+            raise
+
+    async def _asgi_receive(self) -> Message:
+        while self._receive_queue.empty():
+            await anyio.sleep(0)
+        return self._receive_queue.get()
+
+    async def _asgi_send(self, message: Message) -> None:
+        self._send_queue.put(message)
+
+    def _raise_on_close(self, message: Message) -> None:
+        if message["type"] == "websocket.close":
+            raise WebSocketDisconnect(
+                message.get("code", 1000), message.get("reason", "")
+            )
+
+    def send(self, message: Message) -> None:
+        self._receive_queue.put(message)
+
+    def send_text(self, data: str) -> None:
+        self.send({"type": "websocket.receive", "text": data})
+
+    def send_bytes(self, data: bytes) -> None:
+        self.send({"type": "websocket.receive", "bytes": data})
+
+    def send_json(self, data: typing.Any, mode: str = "text") -> None:
+        assert mode in ["text", "binary"]
+        text = json.dumps(data)
+        if mode == "text":
+            self.send({"type": "websocket.receive", "text": text})
+        else:
+            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
+
+    def close(self, code: int = 1000) -> None:
+        self.send({"type": "websocket.disconnect", "code": code})
+
+    def receive(self) -> Message:
+        message = self._send_queue.get()
+        if isinstance(message, BaseException):
+            raise message
+        return message
+
+    def receive_text(self) -> str:
+        message = self.receive()
+        self._raise_on_close(message)
+        return message["text"]
+
+    def receive_bytes(self) -> bytes:
+        message = self.receive()
+        self._raise_on_close(message)
+        return message["bytes"]
+
+    def receive_json(self, mode: str = "text") -> typing.Any:
+        assert mode in ["text", "binary"]
+        message = self.receive()
+        self._raise_on_close(message)
+        if mode == "text":
+            text = message["text"]
+        else:
+            text = message["bytes"].decode("utf-8")
+        return json.loads(text)
+
+
+class _TestClientTransport(httpx.BaseTransport):
     def __init__(
         self,
         app: ASGI3App,
@@ -118,12 +190,12 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
         self.root_path = root_path
         self.portal_factory = portal_factory
 
-    def send(
-        self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
-    ) -> requests.Response:
-        scheme, netloc, path, query, fragment = (
-            str(item) for item in urlsplit(request.url)
-        )
+    def handle_request(self, request: httpx.Request) -> httpx.Response:
+        scheme = request.url.scheme
+        netloc = unquote(request.url.netloc.decode(encoding="ascii"))
+        path = request.url.path
+        raw_path = request.url.raw_path
+        query = unquote(request.url.query.decode(encoding="ascii"))
 
         default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
 
@@ -137,9 +209,9 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
         # Include the 'host' header.
         if "host" in request.headers:
             headers: typing.List[typing.Tuple[bytes, bytes]] = []
-        elif port == default_port:
+        elif port == default_port:  # pragma: no cover
             headers = [(b"host", host.encode())]
-        else:
+        else:  # pragma: no cover
             headers = [(b"host", (f"{host}:{port}").encode())]
 
         # Include other request headers.
@@ -159,7 +231,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
             scope = {
                 "type": "websocket",
                 "path": unquote(path),
-                "raw_path": path.encode(),
+                "raw_path": raw_path,
                 "root_path": self.root_path,
                 "scheme": scheme,
                 "query_string": query.encode(),
@@ -176,7 +248,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
             "http_version": "1.1",
             "method": request.method,
             "path": unquote(path),
-            "raw_path": path.encode(),
+            "raw_path": raw_path,
             "root_path": self.root_path,
             "scheme": scheme,
             "query_string": query.encode(),
@@ -189,7 +261,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
         request_complete = False
         response_started = False
         response_complete: anyio.Event
-        raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()}
+        raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
         template = None
         context = None
 
@@ -201,18 +273,18 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
                     await response_complete.wait()
                 return {"type": "http.disconnect"}
 
-            body = request.body
+            body = request.read()
             if isinstance(body, str):
-                body_bytes: bytes = body.encode("utf-8")
+                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
             elif body is None:
-                body_bytes = b""
-            elif isinstance(body, types.GeneratorType):
-                try:
+                body_bytes = b""  # pragma: no cover
+            elif isinstance(body, GeneratorType):
+                try:  # pragma: no cover
                     chunk = body.send(None)
                     if isinstance(chunk, str):
                         chunk = chunk.encode("utf-8")
                     return {"type": "http.request", "body": chunk, "more_body": True}
-                except StopIteration:
+                except StopIteration:  # pragma: no cover
                     request_complete = True
                     return {"type": "http.request", "body": b""}
             else:
@@ -228,17 +300,11 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
                 assert (
                     not response_started
                 ), 'Received multiple "http.response.start" messages.'
-                raw_kwargs["version"] = 11
-                raw_kwargs["status"] = message["status"]
-                raw_kwargs["reason"] = _get_reason_phrase(message["status"])
+                raw_kwargs["status_code"] = message["status"]
                 raw_kwargs["headers"] = [
                     (key.decode(), value.decode())
                     for key, value in message.get("headers", [])
                 ]
-                raw_kwargs["preload_content"] = False
-                raw_kwargs["original_response"] = _MockOriginalResponse(
-                    raw_kwargs["headers"]
-                )
                 response_started = True
             elif message["type"] == "http.response.body":
                 assert (
@@ -250,9 +316,9 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
                 body = message.get("body", b"")
                 more_body = message.get("more_body", False)
                 if request.method != "HEAD":
-                    raw_kwargs["body"].write(body)
+                    raw_kwargs["stream"].write(body)
                 if not more_body:
-                    raw_kwargs["body"].seek(0)
+                    raw_kwargs["stream"].seek(0)
                     response_complete.set()
             elif message["type"] == "http.response.template":
                 template = message["template"]
@@ -270,153 +336,35 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
             assert response_started, "TestClient did not receive any response."
         elif not response_started:
             raw_kwargs = {
-                "version": 11,
-                "status": 500,
-                "reason": "Internal Server Error",
+                "status_code": 500,
                 "headers": [],
-                "preload_content": False,
-                "original_response": _MockOriginalResponse([]),
-                "body": io.BytesIO(),
+                "stream": io.BytesIO(),
             }
 
-        raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
-        response = self.build_response(request, raw)
+        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())
+
+        response = httpx.Response(**raw_kwargs, request=request)
         if template is not None:
-            response.template = template
-            response.context = context
+            response.template = template  # type: ignore[attr-defined]
+            response.context = context  # type: ignore[attr-defined]
         return response
 
 
-class WebSocketTestSession:
-    def __init__(
-        self,
-        app: ASGI3App,
-        scope: Scope,
-        portal_factory: _PortalFactoryType,
-    ) -> None:
-        self.app = app
-        self.scope = scope
-        self.accepted_subprotocol = None
-        self.extra_headers = None
-        self.portal_factory = portal_factory
-        self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
-        self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
-
-    def __enter__(self) -> "WebSocketTestSession":
-        self.exit_stack = contextlib.ExitStack()
-        self.portal = self.exit_stack.enter_context(self.portal_factory())
-
-        try:
-            _: "Future[None]" = self.portal.start_task_soon(self._run)
-            self.send({"type": "websocket.connect"})
-            message = self.receive()
-            self._raise_on_close(message)
-        except Exception:
-            self.exit_stack.close()
-            raise
-        self.accepted_subprotocol = message.get("subprotocol", None)
-        self.extra_headers = message.get("headers", None)
-        return self
-
-    def __exit__(self, *args: typing.Any) -> None:
-        try:
-            self.close(1000)
-        finally:
-            self.exit_stack.close()
-        while not self._send_queue.empty():
-            message = self._send_queue.get()
-            if isinstance(message, BaseException):
-                raise message
-
-    async def _run(self) -> None:
-        """
-        The sub-thread in which the websocket session runs.
-        """
-        scope = self.scope
-        receive = self._asgi_receive
-        send = self._asgi_send
-        try:
-            await self.app(scope, receive, send)
-        except BaseException as exc:
-            self._send_queue.put(exc)
-            raise
-
-    async def _asgi_receive(self) -> Message:
-        while self._receive_queue.empty():
-            await anyio.sleep(0)
-        return self._receive_queue.get()
-
-    async def _asgi_send(self, message: Message) -> None:
-        self._send_queue.put(message)
-
-    def _raise_on_close(self, message: Message) -> None:
-        if message["type"] == "websocket.close":
-            raise WebSocketDisconnect(
-                message.get("code", 1000), message.get("reason", "")
-            )
-
-    def send(self, message: Message) -> None:
-        self._receive_queue.put(message)
-
-    def send_text(self, data: str) -> None:
-        self.send({"type": "websocket.receive", "text": data})
-
-    def send_bytes(self, data: bytes) -> None:
-        self.send({"type": "websocket.receive", "bytes": data})
-
-    def send_json(self, data: typing.Any, mode: str = "text") -> None:
-        assert mode in ["text", "binary"]
-        text = json.dumps(data)
-        if mode == "text":
-            self.send({"type": "websocket.receive", "text": text})
-        else:
-            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
-
-    def close(self, code: int = 1000) -> None:
-        self.send({"type": "websocket.disconnect", "code": code})
-
-    def receive(self) -> Message:
-        message = self._send_queue.get()
-        if isinstance(message, BaseException):
-            raise message
-        return message
-
-    def receive_text(self) -> str:
-        message = self.receive()
-        self._raise_on_close(message)
-        return message["text"]
-
-    def receive_bytes(self) -> bytes:
-        message = self.receive()
-        self._raise_on_close(message)
-        return message["bytes"]
-
-    def receive_json(self, mode: str = "text") -> typing.Any:
-        assert mode in ["text", "binary"]
-        message = self.receive()
-        self._raise_on_close(message)
-        if mode == "text":
-            text = message["text"]
-        else:
-            text = message["bytes"].decode("utf-8")
-        return json.loads(text)
-
-
-class TestClient(requests.Session):
-    __test__ = False  # For pytest to not discover this up.
+class TestClient(httpx.Client):
+    __test__ = False
     task: "Future[None]"
     portal: typing.Optional[anyio.abc.BlockingPortal] = None
 
     def __init__(
         self,
-        app: typing.Union[ASGI2App, ASGI3App],
+        app: ASGIApp,
         base_url: str = "http://testserver",
         raise_server_exceptions: bool = True,
         root_path: str = "",
         backend: str = "asyncio",
         backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
+        cookies: httpx._client.CookieTypes = None,
     ) -> None:
-        super().__init__()
         self.async_backend = _AsyncBackend(
             backend=backend, backend_options=backend_options or {}
         )
@@ -424,69 +372,320 @@ class TestClient(requests.Session):
             app = typing.cast(ASGI3App, app)
             asgi_app = app
         else:
-            app = typing.cast(ASGI2App, app)
-            asgi_app = _WrapASGI2(app)  #  type: ignore
-        adapter = _ASGIAdapter(
-            asgi_app,
+            app = typing.cast(ASGI2App, app)  # type: ignore[assignment]
+            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
+        self.app = asgi_app
+        transport = _TestClientTransport(
+            self.app,
             portal_factory=self._portal_factory,
             raise_server_exceptions=raise_server_exceptions,
             root_path=root_path,
         )
-        self.mount("http://", adapter)
-        self.mount("https://", adapter)
-        self.mount("ws://", adapter)
-        self.mount("wss://", adapter)
-        self.headers.update({"user-agent": "testclient"})
-        self.app = asgi_app
-        self.base_url = base_url
+        super().__init__(
+            app=self.app,
+            base_url=base_url,
+            headers={"user-agent": "testclient"},
+            transport=transport,
+            follow_redirects=True,
+            cookies=cookies,
+        )
 
     @contextlib.contextmanager
-    def _portal_factory(
-        self,
-    ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
+    def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
         if self.portal is not None:
             yield self.portal
         else:
             with anyio.start_blocking_portal(**self.async_backend) as portal:
                 yield portal
 
-    def request(  # type: ignore
+    def _choose_redirect_arg(
+        self,
+        follow_redirects: typing.Optional[bool],
+        allow_redirects: typing.Optional[bool],
+    ) -> typing.Union[bool, httpx._client.UseClientDefault]:
+        redirect: typing.Union[
+            bool, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT
+        if allow_redirects is not None:
+            message = (
+                "The `allow_redirects` argument is deprecated. "
+                "Use `follow_redirects` instead."
+            )
+            warnings.warn(message, DeprecationWarning)
+            redirect = allow_redirects
+        if follow_redirects is not None:
+            redirect = follow_redirects
+        elif allow_redirects is not None and follow_redirects is not None:
+            raise RuntimeError(  # pragma: no cover
+                "Cannot use both `allow_redirects` and `follow_redirects`."
+            )
+        return redirect
+
+    def request(  # type: ignore[override]
         self,
         method: str,
-        url: str,
-        params: Params = None,
-        data: DataType = None,
-        headers: typing.MutableMapping[str, str] = None,
-        cookies: Cookies = None,
-        files: FileType = None,
-        auth: AuthType = None,
-        timeout: TimeOut = None,
-        allow_redirects: bool = None,
-        proxies: typing.MutableMapping[str, str] = None,
-        hooks: typing.Any = None,
-        stream: bool = None,
-        verify: typing.Union[bool, str] = None,
-        cert: typing.Union[str, typing.Tuple[str, str]] = None,
+        url: httpx._types.URLTypes,
+        *,
+        content: httpx._types.RequestContent = None,
+        data: httpx._types.RequestData = None,
+        files: httpx._types.RequestFiles = None,
         json: typing.Any = None,
-    ) -> requests.Response:
-        url = urljoin(self.base_url, url)
+        params: httpx._types.QueryParamTypes = None,
+        headers: httpx._types.HeaderTypes = None,
+        cookies: httpx._types.CookieTypes = None,
+        auth: typing.Union[
+            httpx._types.AuthTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool = None,
+        allow_redirects: bool = None,
+        timeout: typing.Union[
+            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict = None,
+    ) -> httpx.Response:
+        url = self.base_url.join(url)
+        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
         return super().request(
             method,
+            url,
+            content=content,
+            data=data,
+            files=files,
+            json=json,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=redirect,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def get(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        params: httpx._types.QueryParamTypes = None,
+        headers: httpx._types.HeaderTypes = None,
+        cookies: httpx._types.CookieTypes = None,
+        auth: typing.Union[
+            httpx._types.AuthTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool = None,
+        allow_redirects: bool = None,
+        timeout: typing.Union[
+            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict = None,
+    ) -> httpx.Response:
+        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+        return super().get(
+            url,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=redirect,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def options(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        params: httpx._types.QueryParamTypes = None,
+        headers: httpx._types.HeaderTypes = None,
+        cookies: httpx._types.CookieTypes = None,
+        auth: typing.Union[
+            httpx._types.AuthTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool = None,
+        allow_redirects: bool = None,
+        timeout: typing.Union[
+            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict = None,
+    ) -> httpx.Response:
+        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+        return super().options(
+            url,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=redirect,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def head(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        params: httpx._types.QueryParamTypes = None,
+        headers: httpx._types.HeaderTypes = None,
+        cookies: httpx._types.CookieTypes = None,
+        auth: typing.Union[
+            httpx._types.AuthTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool = None,
+        allow_redirects: bool = None,
+        timeout: typing.Union[
+            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict = None,
+    ) -> httpx.Response:
+        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+        return super().head(
             url,
             params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=redirect,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def post(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        content: httpx._types.RequestContent = None,
+        data: httpx._types.RequestData = None,
+        files: httpx._types.RequestFiles = None,
+        json: typing.Any = None,
+        params: httpx._types.QueryParamTypes = None,
+        headers: httpx._types.HeaderTypes = None,
+        cookies: httpx._types.CookieTypes = None,
+        auth: typing.Union[
+            httpx._types.AuthTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool = None,
+        allow_redirects: bool = None,
+        timeout: typing.Union[
+            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict = None,
+    ) -> httpx.Response:
+        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+        return super().post(
+            url,
+            content=content,
             data=data,
+            files=files,
+            json=json,
+            params=params,
             headers=headers,
             cookies=cookies,
+            auth=auth,
+            follow_redirects=redirect,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def put(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        content: httpx._types.RequestContent = None,
+        data: httpx._types.RequestData = None,
+        files: httpx._types.RequestFiles = None,
+        json: typing.Any = None,
+        params: httpx._types.QueryParamTypes = None,
+        headers: httpx._types.HeaderTypes = None,
+        cookies: httpx._types.CookieTypes = None,
+        auth: typing.Union[
+            httpx._types.AuthTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool = None,
+        allow_redirects: bool = None,
+        timeout: typing.Union[
+            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict = None,
+    ) -> httpx.Response:
+        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+        return super().put(
+            url,
+            content=content,
+            data=data,
             files=files,
+            json=json,
+            params=params,
+            headers=headers,
+            cookies=cookies,
             auth=auth,
+            follow_redirects=redirect,
             timeout=timeout,
-            allow_redirects=allow_redirects,
-            proxies=proxies,
-            hooks=hooks,
-            stream=stream,
-            verify=verify,
-            cert=cert,
+            extensions=extensions,
+        )
+
+    def patch(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        content: httpx._types.RequestContent = None,
+        data: httpx._types.RequestData = None,
+        files: httpx._types.RequestFiles = None,
+        json: typing.Any = None,
+        params: httpx._types.QueryParamTypes = None,
+        headers: httpx._types.HeaderTypes = None,
+        cookies: httpx._types.CookieTypes = None,
+        auth: typing.Union[
+            httpx._types.AuthTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool = None,
+        allow_redirects: bool = None,
+        timeout: typing.Union[
+            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict = None,
+    ) -> httpx.Response:
+        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+        return super().patch(
+            url,
+            content=content,
+            data=data,
+            files=files,
             json=json,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=redirect,
+            timeout=timeout,
+            extensions=extensions,
+        )
+
+    def delete(  # type: ignore[override]
+        self,
+        url: httpx._types.URLTypes,
+        *,
+        params: httpx._types.QueryParamTypes = None,
+        headers: httpx._types.HeaderTypes = None,
+        cookies: httpx._types.CookieTypes = None,
+        auth: typing.Union[
+            httpx._types.AuthTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        follow_redirects: bool = None,
+        allow_redirects: bool = None,
+        timeout: typing.Union[
+            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+        ] = httpx._client.USE_CLIENT_DEFAULT,
+        extensions: dict = None,
+    ) -> httpx.Response:
+        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+        return super().delete(
+            url,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+            auth=auth,
+            follow_redirects=redirect,
+            timeout=timeout,
+            extensions=extensions,
         )
 
     def websocket_connect(
index 910afd9f84f906d5544559b06c8060f4df6d2dd3..ca3d4f47b0dee237f160d88e3021bdc1878c59cc 100644 (file)
@@ -279,9 +279,12 @@ def test_cors_allow_all_methods(test_client_factory):
 
     headers = {"Origin": "https://example.org"}
 
-    for method in ("delete", "get", "head", "options", "patch", "post", "put"):
+    for method in ("patch", "post", "put"):
         response = getattr(client, method)("/", headers=headers, json={})
         assert response.status_code == 200
+    for method in ("delete", "get", "head", "options"):
+        response = getattr(client, method)("/", headers=headers)
+        assert response.status_code == 200
 
 
 def test_cors_allow_origin_regex(test_client_factory):
index a044153a66b57f9296281311de60e78c4218dd49..3f43506c4117139b3d0098a019e6dd50c3227e6a 100644 (file)
@@ -5,6 +5,7 @@ from starlette.middleware import Middleware
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import JSONResponse
 from starlette.routing import Mount, Route
+from starlette.testclient import TestClient
 
 
 def view_session(request):
@@ -74,7 +75,8 @@ def test_session_expires(test_client_factory):
     expired_session_match = re.search(r"session=([^;]*);", expired_cookie_header)
     assert expired_session_match is not None
     expired_session_value = expired_session_match[1]
-    response = client.get("/view_session", cookies={"session": expired_session_value})
+    client = test_client_factory(app, cookies={"session": expired_session_value})
+    response = client.get("/view_session")
     assert response.json() == {"session": {}}
 
 
@@ -128,7 +130,8 @@ def test_session_cookie_subpath(test_client_factory):
     )
     app = Starlette(routes=[Mount("/second_app", app=second_app)])
     client = test_client_factory(app, base_url="http://testserver/second_app")
-    response = client.post("second_app/update_session", json={"some": "data"})
+    response = client.post("/second_app/update_session", json={"some": "data"})
+    assert response.status_code == 200
     cookie = response.headers["set-cookie"]
     cookie_path_match = re.search(r"; path=(\S+);", cookie)
     assert cookie_path_match is not None
@@ -150,7 +153,8 @@ def test_invalid_session_cookie(test_client_factory):
     assert response.json() == {"session": {"some": "data"}}
 
     # we expect it to not raise an exception if we provide a bogus session cookie
-    response = client.get("/view_session", cookies={"session": "invalid"})
+    client = test_client_factory(app, cookies={"session": "invalid"})
+    response = client.get("/view_session")
     assert response.json() == {"session": {}}
 
 
@@ -162,7 +166,7 @@ def test_session_cookie(test_client_factory):
         ],
         middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=None)],
     )
-    client = test_client_factory(app)
+    client: TestClient = test_client_factory(app)
 
     response = client.post("/update_session", json={"some": "data"})
     assert response.json() == {"session": {"some": "data"}}
@@ -171,6 +175,6 @@ def test_session_cookie(test_client_factory):
     set_cookie = response.headers["set-cookie"]
     assert "Max-Age" not in set_cookie
 
-    client.cookies.clear_session_cookies()
+    client.cookies.delete("session")
     response = client.get("/view_session")
     assert response.json() == {"session": {}}
index b7f8cad8c8c13e2d641351531f46c53d145e99f7..4792424abce46789cad2473ed4d881ae3ed21189 100644 (file)
@@ -9,7 +9,6 @@ from starlette.formparsers import MultiPartException, UploadFile, _user_safe_dec
 from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.routing import Mount
-from starlette.testclient import TestClient
 
 
 class ForceMultipartDict(dict):
@@ -114,7 +113,7 @@ def test_multipart_request_files(tmpdir, test_client_factory):
             "test": {
                 "filename": "test.txt",
                 "content": "<file content>",
-                "content_type": "",
+                "content_type": "text/plain",
             }
         }
 
@@ -154,7 +153,7 @@ def test_multipart_request_multiple_files(tmpdir, test_client_factory):
             "test1": {
                 "filename": "test1.txt",
                 "content": "<file1 content>",
-                "content_type": "",
+                "content_type": "text/plain",
             },
             "test2": {
                 "filename": "test2.txt",
@@ -193,8 +192,8 @@ def test_multipart_request_multiple_files_with_headers(tmpdir, test_client_facto
                         "content-disposition",
                         'form-data; name="test2"; filename="test2.txt"',
                     ],
-                    ["content-type", "text/plain"],
                     ["x-custom", "f2"],
+                    ["content-type", "text/plain"],
                 ],
             },
         }
@@ -213,7 +212,7 @@ def test_multi_items(tmpdir, test_client_factory):
     with open(path1, "rb") as f1, open(path2, "rb") as f2:
         response = client.post(
             "/",
-            data=[("test1", "abc")],
+            data={"test1": "abc"},
             files=[("test1", f1), ("test1", ("test2.txt", f2, "text/plain"))],
         )
         assert response.json() == {
@@ -222,7 +221,7 @@ def test_multi_items(tmpdir, test_client_factory):
                 {
                     "filename": "test1.txt",
                     "content": "<file1 content>",
-                    "content_type": "",
+                    "content_type": "text/plain",
                 },
                 {
                     "filename": "test2.txt",
@@ -401,9 +400,7 @@ def test_user_safe_decode_ignores_wrong_charset():
         (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
     ],
 )
-def test_missing_boundary_parameter(
-    app, expectation, test_client_factory: typing.Callable[..., TestClient]
-) -> None:
+def test_missing_boundary_parameter(app, expectation, test_client_factory) -> None:
     client = test_client_factory(app)
     with expectation:
         res = client.post(
@@ -428,7 +425,7 @@ def test_missing_boundary_parameter(
     ],
 )
 def test_missing_name_parameter_on_content_disposition(
-    app, expectation, test_client_factory: typing.Callable[..., TestClient]
+    app, expectation, test_client_factory
 ):
     client = test_client_factory(app)
     with expectation:
index 033df1e6aa94cf5e3848283baa9084d0848d41ae..7422ad72a9d7be2d73677478a9b4895edecbafed 100644 (file)
@@ -431,7 +431,7 @@ def test_chunked_encoding(test_client_factory):
 
     def post_body():
         yield b"foo"
-        yield "bar"
+        yield b"bar"
 
     response = client.post("/", data=post_body())
     assert response.json() == {"body": "foobar"}
index 2030a73824d0fcc1ddf189542272265ccaaef787..608842da2ebdf6a1cd86ae0fbb13718499560ed8 100644 (file)
@@ -191,12 +191,12 @@ def test_response_phrase(test_client_factory):
     app = Response(status_code=204)
     client = test_client_factory(app)
     response = client.get("/")
-    assert response.reason == "No Content"
+    assert response.reason_phrase == "No Content"
 
     app = Response(b"", status_code=123)
     client = test_client_factory(app)
     response = client.get("/")
-    assert response.reason == ""
+    assert response.reason_phrase == ""
 
 
 def test_file_response(tmpdir, test_client_factory):
index 84f1f5d46ad554aa365708c272c1644159ae4281..142c2a00b5699d76adc5dec329576b59ff8e0369 100644 (file)
@@ -171,7 +171,7 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
         file.write("outside root dir")
 
     app = StaticFiles(directory=directory)
-    # We can't test this with 'requests', so we test the app directly here.
+    # We can't test this with 'httpx', so we test the app directly here.
     path = app.get_path({"path": "/../example.txt"})
     scope = {"method": "GET"}