From 76cd611b507635a325a24c741da10081d62c1249 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Sat, 11 Dec 2021 14:35:23 +0100 Subject: [PATCH] Add support for functools.partial in WebsocketRoute (#1356) * Add support for functools.partial in WebsocketRoute * remove commented code * Refactor tests for partian endpoint and ws --- starlette/routing.py | 5 ++- tests/test_routing.py | 81 ++++++++++++++++++++++++++++--------------- 2 files changed, 57 insertions(+), 29 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 3c11c1b0..982980c3 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -276,7 +276,10 @@ class WebSocketRoute(BaseRoute): self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name - if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): + endpoint_handler = endpoint + while isinstance(endpoint_handler, functools.partial): + endpoint_handler = endpoint_handler.func + if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler): # Endpoint is function or method. Treat it as `func(websocket)`. self.app = websocket_session(endpoint) else: diff --git a/tests/test_routing.py b/tests/test_routing.py index e1374cc5..dcb99653 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -32,6 +32,28 @@ def user_no_match(request): # pragma: no cover return Response(content, media_type="text/plain") +async def partial_endpoint(arg, request): + return JSONResponse({"arg": arg}) + + +async def partial_ws_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_json({"url": str(websocket.url)}) + await websocket.close() + + +class PartialRoutes: + @classmethod + async def async_endpoint(cls, arg, request): + return JSONResponse({"arg": arg}) + + @classmethod + async def async_ws_endpoint(cls, websocket: WebSocket): + await websocket.accept() + await websocket.send_json({"url": str(websocket.url)}) + await websocket.close() + + app = Router( [ Route("/", endpoint=homepage, methods=["GET"]), @@ -44,6 +66,21 @@ app = Router( Route("/nomatch", endpoint=user_no_match), ], ), + Mount( + "/partial", + routes=[ + Route("/", endpoint=functools.partial(partial_endpoint, "foo")), + Route( + "/cls", + endpoint=functools.partial(PartialRoutes.async_endpoint, "foo"), + ), + WebSocketRoute("/ws", endpoint=functools.partial(partial_ws_endpoint)), + WebSocketRoute( + "/ws/cls", + endpoint=functools.partial(PartialRoutes.async_ws_endpoint), + ), + ], + ), Mount("/static", app=Response("xxxxx", media_type="image/png")), ] ) @@ -91,14 +128,14 @@ def path_with_parentheses(request): @app.websocket_route("/ws") -async def websocket_endpoint(session): +async def websocket_endpoint(session: WebSocket): await session.accept() await session.send_text("Hello, world!") await session.close() @app.websocket_route("/ws/{room}") -async def websocket_params(session): +async def websocket_params(session: WebSocket): await session.accept() await session.send_text(f"Hello, {session.path_params['room']}!") await session.close() @@ -628,40 +665,28 @@ def test_raise_on_shutdown(test_client_factory): pass # pragma: nocover -class AsyncEndpointClassMethod: - @classmethod - async def async_endpoint(cls, arg, request): - return JSONResponse({"arg": arg}) - - -async def _partial_async_endpoint(arg, request): - return JSONResponse({"arg": arg}) - - -partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo") -partial_cls_async_endpoint = functools.partial( - AsyncEndpointClassMethod.async_endpoint, "foo" -) - -partial_async_app = Router( - routes=[ - Route("/", partial_async_endpoint), - Route("/cls", partial_cls_async_endpoint), - ] -) - - def test_partial_async_endpoint(test_client_factory): - test_client = test_client_factory(partial_async_app) - response = test_client.get("/") + test_client = test_client_factory(app) + response = test_client.get("/partial") assert response.status_code == 200 assert response.json() == {"arg": "foo"} - cls_method_response = test_client.get("/cls") + cls_method_response = test_client.get("/partial/cls") assert cls_method_response.status_code == 200 assert cls_method_response.json() == {"arg": "foo"} +def test_partial_async_ws_endpoint(test_client_factory): + test_client = test_client_factory(app) + with test_client.websocket_connect("/partial/ws") as websocket: + data = websocket.receive_json() + assert data == {"url": "ws://testserver/partial/ws"} + + with test_client.websocket_connect("/partial/ws/cls") as websocket: + data = websocket.receive_json() + assert data == {"url": "ws://testserver/partial/ws/cls"} + + def test_duplicated_param_names(): with pytest.raises( ValueError, -- 2.47.3