]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add `client` parameter to `TestClient` (#2810)
authorIrfanuddin Shafi Ahmed <irfanudeen08@gmail.com>
Sat, 28 Dec 2024 06:05:09 +0000 (10:05 +0400)
committerGitHub <noreply@github.com>
Sat, 28 Dec 2024 06:05:09 +0000 (06:05 +0000)
starlette/testclient.py
tests/test_testclient.py
tests/types.py

index 2c096aa226b88a4d87ea4a7df787e70dc26c21f7..8e908d36f7cc7ec7f7d6d1dc913fb8cd1c1f4053 100644 (file)
@@ -234,6 +234,7 @@ class _TestClientTransport(httpx.BaseTransport):
         raise_server_exceptions: bool = True,
         root_path: str = "",
         *,
+        client: tuple[str, int],
         app_state: dict[str, typing.Any],
     ) -> None:
         self.app = app
@@ -241,6 +242,7 @@ class _TestClientTransport(httpx.BaseTransport):
         self.root_path = root_path
         self.portal_factory = portal_factory
         self.app_state = app_state
+        self.client = client
 
     def handle_request(self, request: httpx.Request) -> httpx.Response:
         scheme = request.url.scheme
@@ -285,7 +287,7 @@ class _TestClientTransport(httpx.BaseTransport):
                 "scheme": scheme,
                 "query_string": query.encode(),
                 "headers": headers,
-                "client": ["testclient", 50000],
+                "client": self.client,
                 "server": [host, port],
                 "subprotocols": subprotocols,
                 "state": self.app_state.copy(),
@@ -304,7 +306,7 @@ class _TestClientTransport(httpx.BaseTransport):
             "scheme": scheme,
             "query_string": query.encode(),
             "headers": headers,
-            "client": ["testclient", 50000],
+            "client": self.client,
             "server": [host, port],
             "extensions": {"http.response.debug": {}},
             "state": self.app_state.copy(),
@@ -409,6 +411,7 @@ class TestClient(httpx.Client):
         cookies: httpx._types.CookieTypes | None = None,
         headers: dict[str, str] | None = None,
         follow_redirects: bool = True,
+        client: tuple[str, int] = ("testclient", 50000),
     ) -> None:
         self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
         if _is_asgi3(app):
@@ -424,6 +427,7 @@ class TestClient(httpx.Client):
             raise_server_exceptions=raise_server_exceptions,
             root_path=root_path,
             app_state=self.app_state,
+            client=client,
         )
         if headers is None:
             headers = {}
index 68593a9ac9c340095e89c36fa65014a28b8903eb..279b81d918be918c40b589f5170a90012872382c 100644 (file)
@@ -282,6 +282,19 @@ def test_client(test_client_factory: TestClientFactory) -> None:
     assert response.json() == {"host": "testclient", "port": 50000}
 
 
+def test_client_custom_client(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        client = scope.get("client")
+        assert client is not None
+        host, port = client
+        response = JSONResponse({"host": host, "port": port})
+        await response(scope, receive, send)
+
+    client = test_client_factory(app, client=("192.168.0.1", 3000))
+    response = client.get("/")
+    assert response.json() == {"host": "192.168.0.1", "port": 3000}
+
+
 @pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà"))
 def test_query_params(test_client_factory: TestClientFactory, param: str) -> None:
     def homepage(request: Request) -> Response:
index 1cbacf107705cbe223f636d7ea798678f32ad006..e4769d308e8591d83ff84b866190397e2b330afc 100644 (file)
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
             cookies: httpx._types.CookieTypes | None = None,
             headers: dict[str, str] | None = None,
             follow_redirects: bool = True,
+            client: tuple[str, int] = ("testclient", 50000),
         ) -> TestClient: ...
 else:  # pragma: no cover