]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
fix(testclient): exclude query sting from `raw_path` (#2716)
authorHao Guan <gh@raptium.net>
Mon, 18 Nov 2024 19:28:43 +0000 (03:28 +0800)
committerGitHub <noreply@github.com>
Mon, 18 Nov 2024 19:28:43 +0000 (20:28 +0100)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
starlette/testclient.py
tests/test_testclient.py

index 5143c4c57f18fb7365a798362c4dff1ca9f83e6e..645ca10900a64658d5e76c80807d476bf2e45c6b 100644 (file)
@@ -281,7 +281,7 @@ class _TestClientTransport(httpx.BaseTransport):
             scope = {
                 "type": "websocket",
                 "path": unquote(path),
-                "raw_path": raw_path,
+                "raw_path": raw_path.split(b"?", 1)[0],
                 "root_path": self.root_path,
                 "scheme": scheme,
                 "query_string": query.encode(),
@@ -300,7 +300,7 @@ class _TestClientTransport(httpx.BaseTransport):
             "http_version": "1.1",
             "method": request.method,
             "path": unquote(path),
-            "raw_path": raw_path,
+            "raw_path": raw_path.split(b"?", 1)[0],
             "root_path": self.root_path,
             "scheme": scheme,
             "query_string": query.encode(),
index 92f16d33649fd76914654ac25358e7712b38077d..68593a9ac9c340095e89c36fa65014a28b8903eb 100644 (file)
@@ -378,3 +378,27 @@ def test_merge_url(test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app, base_url="http://testserver/api/v1/")
     response = client.get("/bar")
     assert response.text == "/api/v1/bar"
+
+
+def test_raw_path_with_querystring(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        response = Response(scope.get("raw_path"))
+        await response(scope, receive, send)
+
+    client = test_client_factory(app)
+    response = client.get("/hello-world", params={"foo": "bar"})
+    assert response.content == b"/hello-world"
+
+
+def test_websocket_raw_path_without_params(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        raw_path = scope.get("raw_path")
+        assert raw_path is not None
+        await websocket.send_bytes(raw_path)
+
+    client = test_client_factory(app)
+    with client.websocket_connect("/hello-world", params={"foo": "bar"}) as websocket:
+        data = websocket.receive_bytes()
+        assert data == b"/hello-world"