]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
✨ Add support for injecting HTTPConnection (#1827)
authorNik <sidnev.nick@gmail.com>
Sun, 9 Aug 2020 13:56:41 +0000 (16:56 +0300)
committerGitHub <noreply@github.com>
Sun, 9 Aug 2020 13:56:41 +0000 (15:56 +0200)
fastapi/dependencies/models.py
fastapi/dependencies/utils.py
fastapi/requests.py
tests/test_http_connection_injection.py [new file with mode: 0644]

index 58685221154e0cae6cd34fe06934b5f7844fb660..8e0c7830a8c91948d4174801903f0b88cb9eeb2b 100644 (file)
@@ -34,6 +34,7 @@ class Dependant:
         call: Optional[Callable] = None,
         request_param_name: Optional[str] = None,
         websocket_param_name: Optional[str] = None,
+        http_connection_param_name: Optional[str] = None,
         response_param_name: Optional[str] = None,
         background_tasks_param_name: Optional[str] = None,
         security_scopes_param_name: Optional[str] = None,
@@ -50,6 +51,7 @@ class Dependant:
         self.security_requirements = security_schemes or []
         self.request_param_name = request_param_name
         self.websocket_param_name = websocket_param_name
+        self.http_connection_param_name = http_connection_param_name
         self.response_param_name = response_param_name
         self.background_tasks_param_name = background_tasks_param_name
         self.security_scopes = security_scopes
index a45a8fe099581d4073f623e984884d38a9cd250a..6c49941ad4a389488fa0016b54a2693f5b9b2b07 100644 (file)
@@ -41,7 +41,7 @@ from pydantic.utils import lenient_issubclass
 from starlette.background import BackgroundTasks
 from starlette.concurrency import run_in_threadpool
 from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
-from starlette.requests import Request
+from starlette.requests import HTTPConnection, Request
 from starlette.responses import Response
 from starlette.websockets import WebSocket
 
@@ -371,6 +371,9 @@ def add_non_field_param_to_dependency(
     elif lenient_issubclass(param.annotation, WebSocket):
         dependant.websocket_param_name = param.name
         return True
+    elif lenient_issubclass(param.annotation, HTTPConnection):
+        dependant.http_connection_param_name = param.name
+        return True
     elif lenient_issubclass(param.annotation, Response):
         dependant.response_param_name = param.name
         return True
@@ -607,6 +610,8 @@ async def solve_dependencies(
         )
         values.update(body_values)
         errors.extend(body_errors)
+    if dependant.http_connection_param_name:
+        values[dependant.http_connection_param_name] = request
     if dependant.request_param_name and isinstance(request, Request):
         values[dependant.request_param_name] = request
     elif dependant.websocket_param_name and isinstance(request, WebSocket):
index eb13f0380077e852a92218197a8442105a5582e7..06d8f01ccfb4d8af65353e2db331ef1de50b6ce0 100644 (file)
@@ -1 +1 @@
-from starlette.requests import Request  # noqa
+from starlette.requests import HTTPConnection, Request  # noqa
diff --git a/tests/test_http_connection_injection.py b/tests/test_http_connection_injection.py
new file mode 100644 (file)
index 0000000..6e321b5
--- /dev/null
@@ -0,0 +1,39 @@
+from fastapi import Depends, FastAPI
+from fastapi.requests import HTTPConnection
+from fastapi.testclient import TestClient
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+app.state.value = 42
+
+
+async def extract_value_from_http_connection(conn: HTTPConnection):
+    return conn.app.state.value
+
+
+@app.get("/http")
+async def get_value_by_http(value: int = Depends(extract_value_from_http_connection)):
+    return value
+
+
+@app.websocket("/ws")
+async def get_value_by_ws(
+    websocket: WebSocket, value: int = Depends(extract_value_from_http_connection)
+):
+    await websocket.accept()
+    await websocket.send_json(value)
+    await websocket.close()
+
+
+client = TestClient(app)
+
+
+def test_value_extracting_by_http():
+    response = client.get("/http")
+    assert response.status_code == 200
+    assert response.json() == 42
+
+
+def test_value_extracting_by_ws():
+    with client.websocket_connect("/ws") as websocket:
+        assert websocket.receive_json() == 42