]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
✨ Include route in scope to allow middleware and other tools to extract its informati...
authorSebastián Ramírez <tiangolo@gmail.com>
Mon, 21 Feb 2022 15:51:26 +0000 (16:51 +0100)
committerGitHub <noreply@github.com>
Mon, 21 Feb 2022 15:51:26 +0000 (16:51 +0100)
fastapi/routing.py
tests/test_route_scope.py [new file with mode: 0644]

index f6d5370d6a3c7e647714ef345b8bcbee40b0a01c..7dae04521ca3d1c167bbbbec03cae4de9d8453c6 100644 (file)
@@ -13,6 +13,7 @@ from typing import (
     Optional,
     Sequence,
     Set,
+    Tuple,
     Type,
     Union,
 )
@@ -44,7 +45,7 @@ from starlette.concurrency import run_in_threadpool
 from starlette.exceptions import HTTPException
 from starlette.requests import Request
 from starlette.responses import JSONResponse, Response
-from starlette.routing import BaseRoute
+from starlette.routing import BaseRoute, Match
 from starlette.routing import Mount as Mount  # noqa
 from starlette.routing import (
     compile_path,
@@ -53,7 +54,7 @@ from starlette.routing import (
     websocket_session,
 )
 from starlette.status import WS_1008_POLICY_VIOLATION
-from starlette.types import ASGIApp
+from starlette.types import ASGIApp, Scope
 from starlette.websockets import WebSocket
 
 
@@ -296,6 +297,12 @@ class APIWebSocketRoute(routing.WebSocketRoute):
         )
         self.path_regex, self.path_format, self.param_convertors = compile_path(path)
 
+    def matches(self, scope: Scope) -> Tuple[Match, Scope]:
+        match, child_scope = super().matches(scope)
+        if match != Match.NONE:
+            child_scope["route"] = self
+        return match, child_scope
+
 
 class APIRoute(routing.Route):
     def __init__(
@@ -432,6 +439,12 @@ class APIRoute(routing.Route):
             dependency_overrides_provider=self.dependency_overrides_provider,
         )
 
+    def matches(self, scope: Scope) -> Tuple[Match, Scope]:
+        match, child_scope = super().matches(scope)
+        if match != Match.NONE:
+            child_scope["route"] = self
+        return match, child_scope
+
 
 class APIRouter(routing.Router):
     def __init__(
diff --git a/tests/test_route_scope.py b/tests/test_route_scope.py
new file mode 100644 (file)
index 0000000..a188e9a
--- /dev/null
@@ -0,0 +1,50 @@
+import pytest
+from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
+from fastapi.routing import APIRoute, APIWebSocketRoute
+from fastapi.testclient import TestClient
+
+app = FastAPI()
+
+
+@app.get("/users/{user_id}")
+async def get_user(user_id: str, request: Request):
+    route: APIRoute = request.scope["route"]
+    return {"user_id": user_id, "path": route.path}
+
+
+@app.websocket("/items/{item_id}")
+async def websocket_item(item_id: str, websocket: WebSocket):
+    route: APIWebSocketRoute = websocket.scope["route"]
+    await websocket.accept()
+    await websocket.send_json({"item_id": item_id, "path": route.path})
+
+
+client = TestClient(app)
+
+
+def test_get():
+    response = client.get("/users/rick")
+    assert response.status_code == 200, response.text
+    assert response.json() == {"user_id": "rick", "path": "/users/{user_id}"}
+
+
+def test_invalid_method_doesnt_match():
+    response = client.post("/users/rick")
+    assert response.status_code == 405, response.text
+
+
+def test_invalid_path_doesnt_match():
+    response = client.post("/usersx/rick")
+    assert response.status_code == 404, response.text
+
+
+def test_websocket():
+    with client.websocket_connect("/items/portal-gun") as websocket:
+        data = websocket.receive_json()
+        assert data == {"item_id": "portal-gun", "path": "/items/{item_id}"}
+
+
+def test_websocket_invalid_path_doesnt_match():
+    with pytest.raises(WebSocketDisconnect):
+        with client.websocket_connect("/itemsx/portal-gun") as websocket:
+            websocket.receive_json()