]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
First pass adding next qparam to redirect (#920)
authorErik <eraker@gmail.com>
Wed, 16 Feb 2022 11:19:08 +0000 (03:19 -0800)
committerGitHub <noreply@github.com>
Wed, 16 Feb 2022 11:19:08 +0000 (11:19 +0000)
Co-authored-by: Tom Christie <tom@tomchristie.com>
docs/authentication.md
starlette/authentication.py
tests/test_authentication.py

index 48eba6ca2f0575b816cfc6663ff7f2559ebc7032..d6cec3fb2053167f4a697818d5c242f7c22d2a4b 100644 (file)
@@ -131,6 +131,29 @@ async def dashboard(request):
     ...
 ```
 
+When redirecting users, the page you redirect them to will include URL they originally requested at the `next` query param:
+
+```python
+from starlette.authentication import requires
+from starlette.responses import RedirectResponse
+
+
+@requires('authenticated', redirect='login')
+async def admin(request):
+    ...
+
+
+async def login(request):
+    if request.method == "POST":
+        # Now that the user is authenticated,
+        # we can send them to their original request destination
+        if request.user.is_authenticated:
+            next_url = request.query_params.get("next")
+            if next_url:
+                return RedirectResponse(next_url)
+            return RedirectResponse("/")
+```
+
 For class-based endpoints, you should wrap the decorator
 around a method on the class.
 
index b4882070d5a8f0cf15e6613e8c7c1ce4d24ff8ee..1a4cba377980220e87cd5a01dff21a440a42c1e8 100644 (file)
@@ -2,6 +2,7 @@ import asyncio
 import functools
 import inspect
 import typing
+from urllib.parse import urlencode
 
 from starlette.exceptions import HTTPException
 from starlette.requests import HTTPConnection, Request
@@ -63,9 +64,12 @@ def requires(
 
                 if not has_required_scope(request, scopes_list):
                     if redirect is not None:
-                        return RedirectResponse(
-                            url=request.url_for(redirect), status_code=303
+                        orig_request_qparam = urlencode({"next": str(request.url)})
+                        next_url = "{redirect_path}?{orig_request}".format(
+                            redirect_path=request.url_for(redirect),
+                            orig_request=orig_request_qparam,
                         )
+                        return RedirectResponse(url=next_url, status_code=303)
                     raise HTTPException(status_code=status_code)
                 return await func(*args, **kwargs)
 
@@ -80,9 +84,12 @@ def requires(
 
                 if not has_required_scope(request, scopes_list):
                     if redirect is not None:
-                        return RedirectResponse(
-                            url=request.url_for(redirect), status_code=303
+                        orig_request_qparam = urlencode({"next": str(request.url)})
+                        next_url = "{redirect_path}?{orig_request}".format(
+                            redirect_path=request.url_for(redirect),
+                            orig_request=orig_request_qparam,
                         )
+                        return RedirectResponse(url=next_url, status_code=303)
                     raise HTTPException(status_code=status_code)
                 return func(*args, **kwargs)
 
index 65b49c3ca50c6e715f857528cbca5758c8183f4c..af0beafd024c32799bb2fa16e8778dd5c07a9340 100644 (file)
@@ -1,5 +1,6 @@
 import base64
 import binascii
+from urllib.parse import urlencode
 
 import pytest
 
@@ -305,7 +306,10 @@ def test_authentication_redirect(test_client_factory):
     with test_client_factory(app) as client:
         response = client.get("/admin")
         assert response.status_code == 200
-        assert response.url == "http://testserver/"
+        url = "{}?{}".format(
+            "http://testserver/", urlencode({"next": "http://testserver/admin"})
+        )
+        assert response.url == url
 
         response = client.get("/admin", auth=("tomchristie", "example"))
         assert response.status_code == 200
@@ -313,7 +317,10 @@ def test_authentication_redirect(test_client_factory):
 
         response = client.get("/admin/sync")
         assert response.status_code == 200
-        assert response.url == "http://testserver/"
+        url = "{}?{}".format(
+            "http://testserver/", urlencode({"next": "http://testserver/admin/sync"})
+        )
+        assert response.url == url
 
         response = client.get("/admin/sync", auth=("tomchristie", "example"))
         assert response.status_code == 200