From fe961dd22ca69741831d1beb5eac5de0da7d8878 Mon Sep 17 00:00:00 2001 From: Vlad Stefan Munteanu Date: Sun, 8 Nov 2020 03:33:11 +0200 Subject: [PATCH] Allow usage of functools.partial async handlers (#984) * Allow usage of async partial methods * Added test for partial async endpoint * Double quotes vs single quotes * Support multiple levels of partials, check Python < 3.8 * Skip coverage for py3.8 branch Co-authored-by: Florimond Manca --- starlette/routing.py | 20 ++++++++++++++++++-- tests/test_routing.py | 16 ++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index ea5d22cb..ce5e4d19 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,6 +1,8 @@ import asyncio +import functools import inspect import re +import sys import traceback import typing from enum import Enum @@ -28,12 +30,23 @@ class Match(Enum): FULL = 2 +def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: + """ + Correctly determines if an object is a coroutine function, + with a fix for partials on Python < 3.8. + """ + if sys.version_info < (3, 8): # pragma: no cover + while isinstance(obj, functools.partial): + obj = obj.func + return inspect.iscoroutinefunction(obj) + + def request_response(func: typing.Callable) -> ASGIApp: """ Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ - is_coroutine = asyncio.iscoroutinefunction(func) + is_coroutine = iscoroutinefunction_or_partial(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive=receive, send=send) @@ -169,7 +182,10 @@ class Route(BaseRoute): self.name = get_name(endpoint) if name is None else name self.include_in_schema = include_in_schema - if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): + endpoint_handler = endpoint + while isinstance(endpoint_handler, functools.partial): + endpoint_handler = endpoint_handler.func + if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler): # Endpoint is function or method. Treat it as `func(request) -> response`. self.app = request_response(endpoint) if methods is None: diff --git a/tests/test_routing.py b/tests/test_routing.py index b06995ff..27640efe 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,3 +1,4 @@ +import functools import uuid import pytest @@ -587,3 +588,18 @@ def test_raise_on_shutdown(): with pytest.raises(RuntimeError): with TestClient(app): pass # pragma: nocover + + +async def _partial_async_endpoint(arg, request): + return JSONResponse({"arg": arg}) + + +partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo") + +partial_async_app = Router(routes=[Route("/", partial_async_endpoint)]) + + +def test_partial_async_endpoint(): + response = TestClient(partial_async_app).get("/") + assert response.status_code == 200 + assert response.json() == {"arg": "foo"} -- 2.47.3