From: Marcelo Trylesinski Date: Sat, 28 May 2022 07:28:58 +0000 (+0200) Subject: Improve detection of async callables (#1444) X-Git-Tag: 0.20.1~1 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=5b058faa108ebe0e7d8810e84bf88cb37744faf8;p=thirdparty%2Fstarlette.git Improve detection of async callables (#1444) * Improve detection of coroutine functions * Remove test from background tasks * Fix coverage * Add test for nested functools * Ignore coverage * Deprecate iscoroutinefunction_or_partial * Ignore coverage for iscoroutinefunction_or_partial * Rename `iscoroutinefunction` to `is_async_callable` --- diff --git a/starlette/_utils.py b/starlette/_utils.py new file mode 100644 index 00000000..0710aebd --- /dev/null +++ b/starlette/_utils.py @@ -0,0 +1,12 @@ +import asyncio +import functools +import typing + + +def is_async_callable(obj: typing.Any) -> bool: + while isinstance(obj, functools.partial): + obj = obj.func + + return asyncio.iscoroutinefunction(obj) or ( + callable(obj) and asyncio.iscoroutinefunction(obj.__call__) + ) diff --git a/starlette/authentication.py b/starlette/authentication.py index 17f4a5ea..4affb438 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -1,9 +1,9 @@ -import asyncio import functools import inspect import typing from urllib.parse import urlencode +from starlette._utils import is_async_callable from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection, Request from starlette.responses import RedirectResponse, Response @@ -53,7 +53,7 @@ def requires( return websocket_wrapper - elif asyncio.iscoroutinefunction(func): + elif is_async_callable(func): # Handle async request/response functions. @functools.wraps(func) async def async_wrapper( diff --git a/starlette/background.py b/starlette/background.py index 145324e3..4aaf7ae3 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,4 +1,3 @@ -import asyncio import sys import typing @@ -7,6 +6,7 @@ if sys.version_info >= (3, 10): # pragma: no cover else: # pragma: no cover from typing_extensions import ParamSpec +from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool P = ParamSpec("P") @@ -19,7 +19,7 @@ class BackgroundTask: self.func = func self.args = args self.kwargs = kwargs - self.is_async = asyncio.iscoroutinefunction(func) + self.is_async = is_async_callable(func) async def __call__(self) -> None: if self.is_async: diff --git a/starlette/endpoints.py b/starlette/endpoints.py index f2468a32..156663e4 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -1,8 +1,8 @@ -import asyncio import json import typing from starlette import status +from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request @@ -37,7 +37,7 @@ class HTTPEndpoint: handler: typing.Callable[[Request], typing.Any] = getattr( self, handler_name, self.method_not_allowed ) - is_async = asyncio.iscoroutinefunction(handler) + is_async = is_async_callable(handler) if is_async: response = await handler(request) else: diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index acb1930f..052b885f 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -1,9 +1,9 @@ -import asyncio import html import inspect import traceback import typing +from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import HTMLResponse, PlainTextResponse, Response @@ -170,7 +170,7 @@ class ServerErrorMiddleware: response = self.error_response(request, exc) else: # Use an installed 500 error handler. - if asyncio.iscoroutinefunction(self.handler): + if is_async_callable(self.handler): response = await self.handler(request, exc) else: response = await run_in_threadpool(self.handler, request, exc) diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index a3b4633d..42fd41ae 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -1,6 +1,6 @@ -import asyncio import typing +from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request @@ -79,7 +79,7 @@ class ExceptionMiddleware: raise RuntimeError(msg) from exc request = Request(scope, receive=receive) - if asyncio.iscoroutinefunction(handler): + if is_async_callable(handler): response = await handler(request, exc) else: response = await run_in_threadpool(handler, request, exc) diff --git a/starlette/routing.py b/starlette/routing.py index 67f12e31..7e10b16f 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import functools import inspect @@ -10,6 +9,7 @@ import warnings from contextlib import asynccontextmanager from enum import Enum +from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.convertors import CONVERTOR_TYPES, Convertor from starlette.datastructures import URL, Headers, URLPath @@ -37,11 +37,16 @@ class Match(Enum): FULL = 2 -def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: +def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover """ Correctly determines if an object is a coroutine function, including those wrapped in functools.partial objects. """ + warnings.warn( + "iscoroutinefunction_or_partial is deprecated, " + "and will be removed in a future release.", + DeprecationWarning, + ) while isinstance(obj, functools.partial): obj = obj.func return inspect.iscoroutinefunction(obj) @@ -52,7 +57,7 @@ def request_response(func: typing.Callable) -> ASGIApp: Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ - is_coroutine = iscoroutinefunction_or_partial(func) + is_coroutine = is_async_callable(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive=receive, send=send) @@ -603,7 +608,7 @@ class Router: Run any `.on_startup` event handlers. """ for handler in self.on_startup: - if asyncio.iscoroutinefunction(handler): + if is_async_callable(handler): await handler() else: handler() @@ -613,7 +618,7 @@ class Router: Run any `.on_shutdown` event handlers. """ for handler in self.on_shutdown: - if asyncio.iscoroutinefunction(handler): + if is_async_callable(handler): await handler() else: handler() diff --git a/starlette/testclient.py b/starlette/testclient.py index 7567d18f..efe2b493 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import http import inspect @@ -16,6 +15,7 @@ import anyio.abc import requests from anyio.streams.stapled import StapledObjectStream +from starlette._utils import is_async_callable from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect @@ -84,10 +84,7 @@ def _get_reason_phrase(status_code: int) -> str: def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool: if inspect.isclass(app): return hasattr(app, "__await__") - elif inspect.isfunction(app): - return asyncio.iscoroutinefunction(app) - call = getattr(app, "__call__", None) - return asyncio.iscoroutinefunction(call) + return is_async_callable(app) class _WrapASGI2: diff --git a/tests/test__utils.py b/tests/test__utils.py new file mode 100644 index 00000000..fac57a2e --- /dev/null +++ b/tests/test__utils.py @@ -0,0 +1,79 @@ +import functools + +from starlette._utils import is_async_callable + + +def test_async_func(): + async def async_func(): + ... # pragma: no cover + + def func(): + ... # pragma: no cover + + assert is_async_callable(async_func) + assert not is_async_callable(func) + + +def test_async_partial(): + async def async_func(a, b): + ... # pragma: no cover + + def func(a, b): + ... # pragma: no cover + + partial = functools.partial(async_func, 1) + assert is_async_callable(partial) + + partial = functools.partial(func, 1) + assert not is_async_callable(partial) + + +def test_async_method(): + class Async: + async def method(self): + ... # pragma: no cover + + class Sync: + def method(self): + ... # pragma: no cover + + assert is_async_callable(Async().method) + assert not is_async_callable(Sync().method) + + +def test_async_object_call(): + class Async: + async def __call__(self): + ... # pragma: no cover + + class Sync: + def __call__(self): + ... # pragma: no cover + + assert is_async_callable(Async()) + assert not is_async_callable(Sync()) + + +def test_async_partial_object_call(): + class Async: + async def __call__(self, a, b): + ... # pragma: no cover + + class Sync: + def __call__(self, a, b): + ... # pragma: no cover + + partial = functools.partial(Async(), 1) + assert is_async_callable(partial) + + partial = functools.partial(Sync(), 1) + assert not is_async_callable(partial) + + +def test_async_nested_partial(): + async def async_func(a, b): + ... # pragma: no cover + + partial = functools.partial(async_func, b=2) + nested_partial = functools.partial(partial, a=1) + assert is_async_callable(nested_partial)