]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Fix regression on route names with colons (#1675)
authorFlorimond Manca <florimond.manca@protonmail.com>
Sat, 4 Jun 2022 16:14:49 +0000 (18:14 +0200)
committerGitHub <noreply@github.com>
Sat, 4 Jun 2022 16:14:49 +0000 (18:14 +0200)
Co-authored-by: Bodo Graumann <mail@bodograumann.de>
starlette/routing.py
tests/test_routing.py

index 7e10b16f942025f34979a5e87f14bb3d6e7bd3e7..1aa2cdb6de626656820cb2952f686c427f1aacbd 100644 (file)
@@ -111,13 +111,16 @@ def compile_path(
     path: str,
 ) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]:
     """
-    Given a path string, like: "/{username:str}", return a three-tuple
+    Given a path string, like: "/{username:str}",
+    or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
     of (regex, format, {param_name:convertor}).
 
     regex:      "/(?P<username>[^/]+)"
     format:     "/{username}"
     convertors: {"username": StringConvertor()}
     """
+    is_host = not path.startswith("/")
+
     path_regex = "^"
     path_format = ""
     duplicated_params = set()
@@ -150,7 +153,13 @@ def compile_path(
         ending = "s" if len(duplicated_params) > 1 else ""
         raise ValueError(f"Duplicated param name{ending} {names} at path {path}")
 
-    path_regex += re.escape(path[idx:].split(":")[0]) + "$"
+    if is_host:
+        # Align with `Host.matches()` behavior, which ignores port.
+        hostname = path[idx:].split(":")[0]
+        path_regex += re.escape(hostname) + "$"
+    else:
+        path_regex += re.escape(path[idx:]) + "$"
+
     path_format += path[idx:]
 
     return re.compile(path_regex), path_format, param_convertors
@@ -429,6 +438,7 @@ class Host(BaseRoute):
     def __init__(
         self, host: str, app: ASGIApp, name: typing.Optional[str] = None
     ) -> None:
+        assert not host.startswith("/"), "Host must not start with '/'"
         self.host = host
         self.app = app
         self.name = name
index e8adaca48bfc0f11689237c3bd360a1ceb1f87ee..e2b1c3dfcd40d011890d29ab3005e9784f7fb721 100644 (file)
@@ -28,6 +28,11 @@ def user_me(request):
     return Response(content, media_type="text/plain")
 
 
+def disable_user(request):
+    content = "User " + request.path_params["username"] + " disabled"
+    return Response(content, media_type="text/plain")
+
+
 def user_no_match(request):  # pragma: no cover
     content = "User fixed no match"
     return Response(content, media_type="text/plain")
@@ -109,6 +114,7 @@ app = Router(
                 Route("/", endpoint=users),
                 Route("/me", endpoint=user_me),
                 Route("/{username}", endpoint=user),
+                Route("/{username}:disable", endpoint=disable_user, methods=["PUT"]),
                 Route("/nomatch", endpoint=user_no_match),
             ],
         ),
@@ -189,6 +195,11 @@ def test_router(client):
     assert response.url == "http://testserver/users/tomchristie"
     assert response.text == "User tomchristie"
 
+    response = client.put("/users/tomchristie:disable")
+    assert response.status_code == 200
+    assert response.url == "http://testserver/users/tomchristie:disable"
+    assert response.text == "User tomchristie disabled"
+
     response = client.get("/users/nomatch")
     assert response.status_code == 200
     assert response.text == "User nomatch"
@@ -429,7 +440,9 @@ def test_host_routing(test_client_factory):
     response = client.get("/")
     assert response.status_code == 200
 
-    client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org/")
+    client = test_client_factory(
+        mixed_hosts_app, base_url="https://port.example.org:3600/"
+    )
 
     response = client.get("/users")
     assert response.status_code == 404
@@ -437,6 +450,13 @@ def test_host_routing(test_client_factory):
     response = client.get("/")
     assert response.status_code == 200
 
+    # Port in requested Host is irrelevant.
+
+    client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org/")
+
+    response = client.get("/")
+    assert response.status_code == 200
+
     client = test_client_factory(
         mixed_hosts_app, base_url="https://port.example.org:5600/"
     )