]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Support `request.url_for` when only "app" scope is avaialable (#2672)
authorJoel Sleppy <jdsleppy@gmail.com>
Sun, 29 Sep 2024 09:40:23 +0000 (05:40 -0400)
committerGitHub <noreply@github.com>
Sun, 29 Sep 2024 09:40:23 +0000 (09:40 +0000)
* Support request.url_for in BaseMiddleware

* Move test to test_requests

* Call the endpoint properly

* fix test

* order the type hint

---------

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
starlette/requests.py
tests/middleware/test_base.py
tests/test_requests.py

index 08dbd84d42fd027d6f5f17ee29f8875acc6fc46c..9c6776f0c4fd86720f2b00a43dffdf1edf51f776 100644 (file)
@@ -19,6 +19,7 @@ except ModuleNotFoundError:  # pragma: no cover
 
 
 if typing.TYPE_CHECKING:
+    from starlette.applications import Starlette
     from starlette.routing import Router
 
 
@@ -175,8 +176,10 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
         return self._state
 
     def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
-        router: Router = self.scope["router"]
-        url_path = router.url_path_for(name, **path_params)
+        url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
+        if url_path_provider is None:
+            raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
+        url_path = url_path_provider.url_path_for(name, **path_params)
         return url_path.make_absolute_url(base_url=self.base_url)
 
 
index e2375a7b9f866de5f0b2427c32d7e775be76df46..041cc7ce2c835cc41c6b74f997775423adc98907 100644 (file)
@@ -2,12 +2,7 @@ from __future__ import annotations
 
 import contextvars
 from contextlib import AsyncExitStack
-from typing import (
-    Any,
-    AsyncGenerator,
-    AsyncIterator,
-    Generator,
-)
+from typing import Any, AsyncGenerator, AsyncIterator, Generator
 
 import anyio
 import pytest
index 2f173713efc472c97e5456df4e82eb4bcf498651..f0494e7513b1f7f1d5f5f9c117856ffeb397e24c 100644 (file)
@@ -6,7 +6,7 @@ from typing import Any, Iterator
 import anyio
 import pytest
 
-from starlette.datastructures import Address, State
+from starlette.datastructures import URL, Address, State
 from starlette.requests import ClientDisconnect, Request
 from starlette.responses import JSONResponse, PlainTextResponse, Response
 from starlette.types import Message, Receive, Scope, Send
@@ -592,3 +592,44 @@ async def test_request_stream_called_twice() -> None:
         assert await s2.__anext__()
     with pytest.raises(StopAsyncIteration):
         await s1.__anext__()
+
+
+def test_request_url_outside_starlette_context(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        request = Request(scope, receive)
+        request.url_for("index")
+
+    client = test_client_factory(app)
+    with pytest.raises(
+        RuntimeError,
+        match="The `url_for` method can only be used inside a Starlette application or with a router.",
+    ):
+        client.get("/")
+
+
+def test_request_url_starlette_context(test_client_factory: TestClientFactory) -> None:
+    from starlette.applications import Starlette
+    from starlette.middleware import Middleware
+    from starlette.routing import Route
+    from starlette.types import ASGIApp
+
+    url_for = None
+
+    async def homepage(request: Request) -> Response:
+        return PlainTextResponse("Hello, world!")
+
+    class CustomMiddleware:
+        def __init__(self, app: ASGIApp) -> None:
+            self.app = app
+
+        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+            nonlocal url_for
+            request = Request(scope, receive)
+            url_for = request.url_for("homepage")
+            await self.app(scope, receive, send)
+
+    app = Starlette(routes=[Route("/home", homepage)], middleware=[Middleware(CustomMiddleware)])
+
+    client = test_client_factory(app)
+    client.get("/home")
+    assert url_for == URL("http://testserver/home")