]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add support for functools.partial in WebsocketRoute (#1356)
authorAmin Alaee <mohammadamin.alaee@gmail.com>
Sat, 11 Dec 2021 13:35:23 +0000 (14:35 +0100)
committerGitHub <noreply@github.com>
Sat, 11 Dec 2021 13:35:23 +0000 (14:35 +0100)
* Add support for functools.partial in WebsocketRoute

* remove commented code

* Refactor tests for partian endpoint and ws

starlette/routing.py
tests/test_routing.py

index 3c11c1b0cc0152f2955680f2461fe331c0932119..982980c3ce1969001944c11de8f4867bf7eeb69c 100644 (file)
@@ -276,7 +276,10 @@ class WebSocketRoute(BaseRoute):
         self.endpoint = endpoint
         self.name = get_name(endpoint) if name is None else name
 
-        if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
+        endpoint_handler = endpoint
+        while isinstance(endpoint_handler, functools.partial):
+            endpoint_handler = endpoint_handler.func
+        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
             # Endpoint is function or method. Treat it as `func(websocket)`.
             self.app = websocket_session(endpoint)
         else:
index e1374cc5d9409f301e2d302e9532f8c39b41709e..dcb99653108440744234e00b552fe200a877c865 100644 (file)
@@ -32,6 +32,28 @@ def user_no_match(request):  # pragma: no cover
     return Response(content, media_type="text/plain")
 
 
+async def partial_endpoint(arg, request):
+    return JSONResponse({"arg": arg})
+
+
+async def partial_ws_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    await websocket.send_json({"url": str(websocket.url)})
+    await websocket.close()
+
+
+class PartialRoutes:
+    @classmethod
+    async def async_endpoint(cls, arg, request):
+        return JSONResponse({"arg": arg})
+
+    @classmethod
+    async def async_ws_endpoint(cls, websocket: WebSocket):
+        await websocket.accept()
+        await websocket.send_json({"url": str(websocket.url)})
+        await websocket.close()
+
+
 app = Router(
     [
         Route("/", endpoint=homepage, methods=["GET"]),
@@ -44,6 +66,21 @@ app = Router(
                 Route("/nomatch", endpoint=user_no_match),
             ],
         ),
+        Mount(
+            "/partial",
+            routes=[
+                Route("/", endpoint=functools.partial(partial_endpoint, "foo")),
+                Route(
+                    "/cls",
+                    endpoint=functools.partial(PartialRoutes.async_endpoint, "foo"),
+                ),
+                WebSocketRoute("/ws", endpoint=functools.partial(partial_ws_endpoint)),
+                WebSocketRoute(
+                    "/ws/cls",
+                    endpoint=functools.partial(PartialRoutes.async_ws_endpoint),
+                ),
+            ],
+        ),
         Mount("/static", app=Response("xxxxx", media_type="image/png")),
     ]
 )
@@ -91,14 +128,14 @@ def path_with_parentheses(request):
 
 
 @app.websocket_route("/ws")
-async def websocket_endpoint(session):
+async def websocket_endpoint(session: WebSocket):
     await session.accept()
     await session.send_text("Hello, world!")
     await session.close()
 
 
 @app.websocket_route("/ws/{room}")
-async def websocket_params(session):
+async def websocket_params(session: WebSocket):
     await session.accept()
     await session.send_text(f"Hello, {session.path_params['room']}!")
     await session.close()
@@ -628,40 +665,28 @@ def test_raise_on_shutdown(test_client_factory):
             pass  # pragma: nocover
 
 
-class AsyncEndpointClassMethod:
-    @classmethod
-    async def async_endpoint(cls, arg, request):
-        return JSONResponse({"arg": arg})
-
-
-async def _partial_async_endpoint(arg, request):
-    return JSONResponse({"arg": arg})
-
-
-partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo")
-partial_cls_async_endpoint = functools.partial(
-    AsyncEndpointClassMethod.async_endpoint, "foo"
-)
-
-partial_async_app = Router(
-    routes=[
-        Route("/", partial_async_endpoint),
-        Route("/cls", partial_cls_async_endpoint),
-    ]
-)
-
-
 def test_partial_async_endpoint(test_client_factory):
-    test_client = test_client_factory(partial_async_app)
-    response = test_client.get("/")
+    test_client = test_client_factory(app)
+    response = test_client.get("/partial")
     assert response.status_code == 200
     assert response.json() == {"arg": "foo"}
 
-    cls_method_response = test_client.get("/cls")
+    cls_method_response = test_client.get("/partial/cls")
     assert cls_method_response.status_code == 200
     assert cls_method_response.json() == {"arg": "foo"}
 
 
+def test_partial_async_ws_endpoint(test_client_factory):
+    test_client = test_client_factory(app)
+    with test_client.websocket_connect("/partial/ws") as websocket:
+        data = websocket.receive_json()
+        assert data == {"url": "ws://testserver/partial/ws"}
+
+    with test_client.websocket_connect("/partial/ws/cls") as websocket:
+        data = websocket.receive_json()
+        assert data == {"url": "ws://testserver/partial/ws/cls"}
+
+
 def test_duplicated_param_names():
     with pytest.raises(
         ValueError,