--- /dev/null
+import asyncio
+import functools
+import typing
+
+
+def is_async_callable(obj: typing.Any) -> bool:
+ while isinstance(obj, functools.partial):
+ obj = obj.func
+
+ return asyncio.iscoroutinefunction(obj) or (
+ callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
+ )
-import asyncio
import functools
import inspect
import typing
from urllib.parse import urlencode
+from starlette._utils import is_async_callable
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
return websocket_wrapper
- elif asyncio.iscoroutinefunction(func):
+ elif is_async_callable(func):
# Handle async request/response functions.
@functools.wraps(func)
async def async_wrapper(
-import asyncio
import sys
import typing
else: # pragma: no cover
from typing_extensions import ParamSpec
+from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
P = ParamSpec("P")
self.func = func
self.args = args
self.kwargs = kwargs
- self.is_async = asyncio.iscoroutinefunction(func)
+ self.is_async = is_async_callable(func)
async def __call__(self) -> None:
if self.is_async:
-import asyncio
import json
import typing
from starlette import status
+from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
handler: typing.Callable[[Request], typing.Any] = getattr(
self, handler_name, self.method_not_allowed
)
- is_async = asyncio.iscoroutinefunction(handler)
+ is_async = is_async_callable(handler)
if is_async:
response = await handler(request)
else:
-import asyncio
import html
import inspect
import traceback
import typing
+from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import HTMLResponse, PlainTextResponse, Response
response = self.error_response(request, exc)
else:
# Use an installed 500 error handler.
- if asyncio.iscoroutinefunction(self.handler):
+ if is_async_callable(self.handler):
response = await self.handler(request, exc)
else:
response = await run_in_threadpool(self.handler, request, exc)
-import asyncio
import typing
+from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
raise RuntimeError(msg) from exc
request = Request(scope, receive=receive)
- if asyncio.iscoroutinefunction(handler):
+ if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
-import asyncio
import contextlib
import functools
import inspect
from contextlib import asynccontextmanager
from enum import Enum
+from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, Headers, URLPath
FULL = 2
-def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:
+def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
"""
Correctly determines if an object is a coroutine function,
including those wrapped in functools.partial objects.
"""
+ warnings.warn(
+ "iscoroutinefunction_or_partial is deprecated, "
+ "and will be removed in a future release.",
+ DeprecationWarning,
+ )
while isinstance(obj, functools.partial):
obj = obj.func
return inspect.iscoroutinefunction(obj)
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""
- is_coroutine = iscoroutinefunction_or_partial(func)
+ is_coroutine = is_async_callable(func)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive=receive, send=send)
Run any `.on_startup` event handlers.
"""
for handler in self.on_startup:
- if asyncio.iscoroutinefunction(handler):
+ if is_async_callable(handler):
await handler()
else:
handler()
Run any `.on_shutdown` event handlers.
"""
for handler in self.on_shutdown:
- if asyncio.iscoroutinefunction(handler):
+ if is_async_callable(handler):
await handler()
else:
handler()
-import asyncio
import contextlib
import http
import inspect
import requests
from anyio.streams.stapled import StapledObjectStream
+from starlette._utils import is_async_callable
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
if inspect.isclass(app):
return hasattr(app, "__await__")
- elif inspect.isfunction(app):
- return asyncio.iscoroutinefunction(app)
- call = getattr(app, "__call__", None)
- return asyncio.iscoroutinefunction(call)
+ return is_async_callable(app)
class _WrapASGI2:
--- /dev/null
+import functools
+
+from starlette._utils import is_async_callable
+
+
+def test_async_func():
+ async def async_func():
+ ... # pragma: no cover
+
+ def func():
+ ... # pragma: no cover
+
+ assert is_async_callable(async_func)
+ assert not is_async_callable(func)
+
+
+def test_async_partial():
+ async def async_func(a, b):
+ ... # pragma: no cover
+
+ def func(a, b):
+ ... # pragma: no cover
+
+ partial = functools.partial(async_func, 1)
+ assert is_async_callable(partial)
+
+ partial = functools.partial(func, 1)
+ assert not is_async_callable(partial)
+
+
+def test_async_method():
+ class Async:
+ async def method(self):
+ ... # pragma: no cover
+
+ class Sync:
+ def method(self):
+ ... # pragma: no cover
+
+ assert is_async_callable(Async().method)
+ assert not is_async_callable(Sync().method)
+
+
+def test_async_object_call():
+ class Async:
+ async def __call__(self):
+ ... # pragma: no cover
+
+ class Sync:
+ def __call__(self):
+ ... # pragma: no cover
+
+ assert is_async_callable(Async())
+ assert not is_async_callable(Sync())
+
+
+def test_async_partial_object_call():
+ class Async:
+ async def __call__(self, a, b):
+ ... # pragma: no cover
+
+ class Sync:
+ def __call__(self, a, b):
+ ... # pragma: no cover
+
+ partial = functools.partial(Async(), 1)
+ assert is_async_callable(partial)
+
+ partial = functools.partial(Sync(), 1)
+ assert not is_async_callable(partial)
+
+
+def test_async_nested_partial():
+ async def async_func(a, b):
+ ... # pragma: no cover
+
+ partial = functools.partial(async_func, b=2)
+ nested_partial = functools.partial(partial, a=1)
+ assert is_async_callable(nested_partial)