call: Optional[Callable] = None,
request_param_name: Optional[str] = None,
websocket_param_name: Optional[str] = None,
+ http_connection_param_name: Optional[str] = None,
response_param_name: Optional[str] = None,
background_tasks_param_name: Optional[str] = None,
security_scopes_param_name: Optional[str] = None,
self.security_requirements = security_schemes or []
self.request_param_name = request_param_name
self.websocket_param_name = websocket_param_name
+ self.http_connection_param_name = http_connection_param_name
self.response_param_name = response_param_name
self.background_tasks_param_name = background_tasks_param_name
self.security_scopes = security_scopes
from starlette.background import BackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
-from starlette.requests import Request
+from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
from starlette.websockets import WebSocket
elif lenient_issubclass(param.annotation, WebSocket):
dependant.websocket_param_name = param.name
return True
+ elif lenient_issubclass(param.annotation, HTTPConnection):
+ dependant.http_connection_param_name = param.name
+ return True
elif lenient_issubclass(param.annotation, Response):
dependant.response_param_name = param.name
return True
)
values.update(body_values)
errors.extend(body_errors)
+ if dependant.http_connection_param_name:
+ values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request):
values[dependant.request_param_name] = request
elif dependant.websocket_param_name and isinstance(request, WebSocket):
-from starlette.requests import Request # noqa
+from starlette.requests import HTTPConnection, Request # noqa
--- /dev/null
+from fastapi import Depends, FastAPI
+from fastapi.requests import HTTPConnection
+from fastapi.testclient import TestClient
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+app.state.value = 42
+
+
+async def extract_value_from_http_connection(conn: HTTPConnection):
+ return conn.app.state.value
+
+
+@app.get("/http")
+async def get_value_by_http(value: int = Depends(extract_value_from_http_connection)):
+ return value
+
+
+@app.websocket("/ws")
+async def get_value_by_ws(
+ websocket: WebSocket, value: int = Depends(extract_value_from_http_connection)
+):
+ await websocket.accept()
+ await websocket.send_json(value)
+ await websocket.close()
+
+
+client = TestClient(app)
+
+
+def test_value_extracting_by_http():
+ response = client.get("/http")
+ assert response.status_code == 200
+ assert response.json() == 42
+
+
+def test_value_extracting_by_ws():
+ with client.websocket_connect("/ws") as websocket:
+ assert websocket.receive_json() == 42