]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Ensure accurate `root_path` removal in `get_route_path` function (#2600)
authorGabriel Figueiredo <53541827+gabriel-f-santos@users.noreply.github.com>
Sun, 1 Sep 2024 15:12:43 +0000 (12:12 -0300)
committerGitHub <noreply@github.com>
Sun, 1 Sep 2024 15:12:43 +0000 (17:12 +0200)
* fix: regex inside function get_route_path to remove root_path

* fix: apply format ruff

* fix: mypy

---------

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
starlette/_utils.py
tests/test__utils.py
tests/test_routing.py

index 90bd346fd1211c99e0c0787940f6501263a7dc62..f615eeea4fee078d82fc8a132b0f52728cd47cdd 100644 (file)
@@ -85,5 +85,5 @@ def collapse_excgroups() -> typing.Generator[None, None, None]:
 
 def get_route_path(scope: Scope) -> str:
     root_path = scope.get("root_path", "")
-    route_path = re.sub(r"^" + root_path, "", scope["path"])
+    route_path = re.sub(r"^" + root_path + r"(?=/|$)", "", scope["path"])
     return route_path
index 2ecf4748318a680f44576850148427cc4e6282df..f46775b4b2490def15b90d6dc49d3b70e45d3182 100644 (file)
@@ -1,7 +1,10 @@
 import functools
 from typing import Any
 
-from starlette._utils import is_async_callable
+import pytest
+
+from starlette._utils import get_route_path, is_async_callable
+from starlette.types import Scope
 
 
 def test_async_func() -> None:
@@ -78,3 +81,15 @@ def test_async_nested_partial() -> None:
     partial = functools.partial(async_func, b=2)
     nested_partial = functools.partial(partial, a=1)
     assert is_async_callable(nested_partial)
+
+
+@pytest.mark.parametrize(
+    "scope, expected_result",
+    [
+        ({"path": "/foo-123/bar", "root_path": "/foo"}, "/foo-123/bar"),
+        ({"path": "/foo/bar", "root_path": "/foo"}, "/bar"),
+        ({"path": "/foo", "root_path": "/foo"}, ""),
+    ],
+)
+def test_get_route_path(scope: Scope, expected_result: str) -> None:
+    assert get_route_path(scope) == expected_result
index 9fa44def4c04c3a58c58e7c678fb13db85472e9b..6bb398ba5d7fa69ff11bba4313ee2fb1c4902139 100644 (file)
@@ -1221,6 +1221,12 @@ echo_paths_routes = [
         name="path",
         methods=["GET"],
     ),
+    Route(
+        "/root-queue/path",
+        functools.partial(echo_paths, name="queue_path"),
+        name="queue_path",
+        methods=["POST"],
+    ),
     Mount("/asgipath", app=functools.partial(pure_asgi_echo_paths, name="asgipath")),
     Mount(
         "/sub",
@@ -1266,3 +1272,11 @@ def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None:
         "path": "/root/sub/path",
         "root_path": "/root/sub",
     }
+
+    response = client.post("/root/root-queue/path")
+    assert response.status_code == 200
+    assert response.json() == {
+        "name": "queue_path",
+        "path": "/root/root-queue/path",
+        "root_path": "/root",
+    }