]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Improve detection of async callables (#1444)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Sat, 28 May 2022 07:28:58 +0000 (09:28 +0200)
committerGitHub <noreply@github.com>
Sat, 28 May 2022 07:28:58 +0000 (09:28 +0200)
* Improve detection of coroutine functions

* Remove test from background tasks

* Fix coverage

* Add test for nested functools

* Ignore coverage

* Deprecate iscoroutinefunction_or_partial

* Ignore coverage for iscoroutinefunction_or_partial

* Rename `iscoroutinefunction` to `is_async_callable`

starlette/_utils.py [new file with mode: 0644]
starlette/authentication.py
starlette/background.py
starlette/endpoints.py
starlette/middleware/errors.py
starlette/middleware/exceptions.py
starlette/routing.py
starlette/testclient.py
tests/test__utils.py [new file with mode: 0644]

diff --git a/starlette/_utils.py b/starlette/_utils.py
new file mode 100644 (file)
index 0000000..0710aeb
--- /dev/null
@@ -0,0 +1,12 @@
+import asyncio
+import functools
+import typing
+
+
+def is_async_callable(obj: typing.Any) -> bool:
+    while isinstance(obj, functools.partial):
+        obj = obj.func
+
+    return asyncio.iscoroutinefunction(obj) or (
+        callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
+    )
index 17f4a5eada706e0a26f7a8dcaaf774c49de67edd..4affb4383f24ce2583d10c66ca46ae2ed087699c 100644 (file)
@@ -1,9 +1,9 @@
-import asyncio
 import functools
 import inspect
 import typing
 from urllib.parse import urlencode
 
+from starlette._utils import is_async_callable
 from starlette.exceptions import HTTPException
 from starlette.requests import HTTPConnection, Request
 from starlette.responses import RedirectResponse, Response
@@ -53,7 +53,7 @@ def requires(
 
             return websocket_wrapper
 
-        elif asyncio.iscoroutinefunction(func):
+        elif is_async_callable(func):
             # Handle async request/response functions.
             @functools.wraps(func)
             async def async_wrapper(
index 145324e3fffaff33fe8b3b2f4a702f2ac96a0eae..4aaf7ae3cf4addc0e0c64434419abf9462599078 100644 (file)
@@ -1,4 +1,3 @@
-import asyncio
 import sys
 import typing
 
@@ -7,6 +6,7 @@ if sys.version_info >= (3, 10):  # pragma: no cover
 else:  # pragma: no cover
     from typing_extensions import ParamSpec
 
+from starlette._utils import is_async_callable
 from starlette.concurrency import run_in_threadpool
 
 P = ParamSpec("P")
@@ -19,7 +19,7 @@ class BackgroundTask:
         self.func = func
         self.args = args
         self.kwargs = kwargs
-        self.is_async = asyncio.iscoroutinefunction(func)
+        self.is_async = is_async_callable(func)
 
     async def __call__(self) -> None:
         if self.is_async:
index f2468a326d53755f38bfd413825c4c6c0672b556..156663e4901bff6d2c2645f3ee1f49b4678c8e73 100644 (file)
@@ -1,8 +1,8 @@
-import asyncio
 import json
 import typing
 
 from starlette import status
+from starlette._utils import is_async_callable
 from starlette.concurrency import run_in_threadpool
 from starlette.exceptions import HTTPException
 from starlette.requests import Request
@@ -37,7 +37,7 @@ class HTTPEndpoint:
         handler: typing.Callable[[Request], typing.Any] = getattr(
             self, handler_name, self.method_not_allowed
         )
-        is_async = asyncio.iscoroutinefunction(handler)
+        is_async = is_async_callable(handler)
         if is_async:
             response = await handler(request)
         else:
index acb1930f33ffeb70f7f2693e07f0fdde4a5dcc6f..052b885f43649f228c8992b044ab4944a6fe954e 100644 (file)
@@ -1,9 +1,9 @@
-import asyncio
 import html
 import inspect
 import traceback
 import typing
 
+from starlette._utils import is_async_callable
 from starlette.concurrency import run_in_threadpool
 from starlette.requests import Request
 from starlette.responses import HTMLResponse, PlainTextResponse, Response
@@ -170,7 +170,7 @@ class ServerErrorMiddleware:
                 response = self.error_response(request, exc)
             else:
                 # Use an installed 500 error handler.
-                if asyncio.iscoroutinefunction(self.handler):
+                if is_async_callable(self.handler):
                     response = await self.handler(request, exc)
                 else:
                     response = await run_in_threadpool(self.handler, request, exc)
index a3b4633d2764dea49498dbc79975fc7e5f58f4bc..42fd41ae2fd528aeffd675bdc062483f8e0a0e28 100644 (file)
@@ -1,6 +1,6 @@
-import asyncio
 import typing
 
+from starlette._utils import is_async_callable
 from starlette.concurrency import run_in_threadpool
 from starlette.exceptions import HTTPException
 from starlette.requests import Request
@@ -79,7 +79,7 @@ class ExceptionMiddleware:
                 raise RuntimeError(msg) from exc
 
             request = Request(scope, receive=receive)
-            if asyncio.iscoroutinefunction(handler):
+            if is_async_callable(handler):
                 response = await handler(request, exc)
             else:
                 response = await run_in_threadpool(handler, request, exc)
index 67f12e311109df15e448d816970a9b7bc67b701f..7e10b16f942025f34979a5e87f14bb3d6e7bd3e7 100644 (file)
@@ -1,4 +1,3 @@
-import asyncio
 import contextlib
 import functools
 import inspect
@@ -10,6 +9,7 @@ import warnings
 from contextlib import asynccontextmanager
 from enum import Enum
 
+from starlette._utils import is_async_callable
 from starlette.concurrency import run_in_threadpool
 from starlette.convertors import CONVERTOR_TYPES, Convertor
 from starlette.datastructures import URL, Headers, URLPath
@@ -37,11 +37,16 @@ class Match(Enum):
     FULL = 2
 
 
-def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:
+def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:  # pragma: no cover
     """
     Correctly determines if an object is a coroutine function,
     including those wrapped in functools.partial objects.
     """
+    warnings.warn(
+        "iscoroutinefunction_or_partial is deprecated, "
+        "and will be removed in a future release.",
+        DeprecationWarning,
+    )
     while isinstance(obj, functools.partial):
         obj = obj.func
     return inspect.iscoroutinefunction(obj)
@@ -52,7 +57,7 @@ def request_response(func: typing.Callable) -> ASGIApp:
     Takes a function or coroutine `func(request) -> response`,
     and returns an ASGI application.
     """
-    is_coroutine = iscoroutinefunction_or_partial(func)
+    is_coroutine = is_async_callable(func)
 
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive=receive, send=send)
@@ -603,7 +608,7 @@ class Router:
         Run any `.on_startup` event handlers.
         """
         for handler in self.on_startup:
-            if asyncio.iscoroutinefunction(handler):
+            if is_async_callable(handler):
                 await handler()
             else:
                 handler()
@@ -613,7 +618,7 @@ class Router:
         Run any `.on_shutdown` event handlers.
         """
         for handler in self.on_shutdown:
-            if asyncio.iscoroutinefunction(handler):
+            if is_async_callable(handler):
                 await handler()
             else:
                 handler()
index 7567d18fd69acb20f5aced1dc95671ad5120a482..efe2b493bb0dd42f3a3262a5eb1fc3d450a5fb51 100644 (file)
@@ -1,4 +1,3 @@
-import asyncio
 import contextlib
 import http
 import inspect
@@ -16,6 +15,7 @@ import anyio.abc
 import requests
 from anyio.streams.stapled import StapledObjectStream
 
+from starlette._utils import is_async_callable
 from starlette.types import Message, Receive, Scope, Send
 from starlette.websockets import WebSocketDisconnect
 
@@ -84,10 +84,7 @@ def _get_reason_phrase(status_code: int) -> str:
 def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
     if inspect.isclass(app):
         return hasattr(app, "__await__")
-    elif inspect.isfunction(app):
-        return asyncio.iscoroutinefunction(app)
-    call = getattr(app, "__call__", None)
-    return asyncio.iscoroutinefunction(call)
+    return is_async_callable(app)
 
 
 class _WrapASGI2:
diff --git a/tests/test__utils.py b/tests/test__utils.py
new file mode 100644 (file)
index 0000000..fac57a2
--- /dev/null
@@ -0,0 +1,79 @@
+import functools
+
+from starlette._utils import is_async_callable
+
+
+def test_async_func():
+    async def async_func():
+        ...  # pragma: no cover
+
+    def func():
+        ...  # pragma: no cover
+
+    assert is_async_callable(async_func)
+    assert not is_async_callable(func)
+
+
+def test_async_partial():
+    async def async_func(a, b):
+        ...  # pragma: no cover
+
+    def func(a, b):
+        ...  # pragma: no cover
+
+    partial = functools.partial(async_func, 1)
+    assert is_async_callable(partial)
+
+    partial = functools.partial(func, 1)
+    assert not is_async_callable(partial)
+
+
+def test_async_method():
+    class Async:
+        async def method(self):
+            ...  # pragma: no cover
+
+    class Sync:
+        def method(self):
+            ...  # pragma: no cover
+
+    assert is_async_callable(Async().method)
+    assert not is_async_callable(Sync().method)
+
+
+def test_async_object_call():
+    class Async:
+        async def __call__(self):
+            ...  # pragma: no cover
+
+    class Sync:
+        def __call__(self):
+            ...  # pragma: no cover
+
+    assert is_async_callable(Async())
+    assert not is_async_callable(Sync())
+
+
+def test_async_partial_object_call():
+    class Async:
+        async def __call__(self, a, b):
+            ...  # pragma: no cover
+
+    class Sync:
+        def __call__(self, a, b):
+            ...  # pragma: no cover
+
+    partial = functools.partial(Async(), 1)
+    assert is_async_callable(partial)
+
+    partial = functools.partial(Sync(), 1)
+    assert not is_async_callable(partial)
+
+
+def test_async_nested_partial():
+    async def async_func(a, b):
+        ...  # pragma: no cover
+
+    partial = functools.partial(async_func, b=2)
+    nested_partial = functools.partial(partial, a=1)
+    assert is_async_callable(nested_partial)