From: Tom Christie Date: Wed, 7 Nov 2018 13:38:29 +0000 (+0000) Subject: Add `async run_in_threadpool(func, *args, **kwargs)` (#192) X-Git-Tag: 0.7.4~1 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=b686174d0e74fc1f0b61da231f062e50f8c39e68;p=thirdparty%2Fstarlette.git Add `async run_in_threadpool(func, *args, **kwargs)` (#192) * Add run_in_threadpool * Fix contextvars support with threadpools --- diff --git a/starlette/background.py b/starlette/background.py index a666cc61..dbe75924 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,7 +1,8 @@ import asyncio -import functools import typing +from starlette.concurrency import run_in_threadpool + class BackgroundTask: def __init__( @@ -10,11 +11,10 @@ class BackgroundTask: self.func = func self.args = args self.kwargs = kwargs + self.is_async = asyncio.iscoroutinefunction(func) async def __call__(self) -> None: - if asyncio.iscoroutinefunction(self.func): - await asyncio.ensure_future(self.func(*self.args, **self.kwargs)) + if self.is_async: + await self.func(*self.args, **self.kwargs) else: - fn = functools.partial(self.func, *self.args, **self.kwargs) - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, fn) + await run_in_threadpool(self.func, *self.args, **self.kwargs) diff --git a/starlette/concurrency.py b/starlette/concurrency.py new file mode 100644 index 00000000..35b58995 --- /dev/null +++ b/starlette/concurrency.py @@ -0,0 +1,24 @@ +import asyncio +import functools +import typing + +try: + import contextvars # Python 3.7+ only. +except ImportError: # pragma: no cover + contextvars = None # type: ignore + + +async def run_in_threadpool( + func: typing.Callable, *args: typing.Any, **kwargs: typing.Any +) -> typing.Any: + loop = asyncio.get_event_loop() + if contextvars is not None: # pragma: no cover + # Ensure we run in the same context + child = functools.partial(func, *args, **kwargs) + context = contextvars.copy_context() + func = context.run + args = (child,) + elif kwargs: # pragma: no cover + # loop.run_in_executor doesn't accept 'kwargs', so bind them in here + func = functools.partial(func, **kwargs) + return await loop.run_in_executor(None, func, *args) diff --git a/starlette/endpoints.py b/starlette/endpoints.py index 7d486a7d..4694515e 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -3,6 +3,7 @@ import json import typing from starlette import status +from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import PlainTextResponse, Response @@ -23,11 +24,11 @@ class HTTPEndpoint: async def dispatch(self, request: Request) -> Response: handler_name = "get" if request.method == "HEAD" else request.method.lower() handler = getattr(self, handler_name, self.method_not_allowed) - if asyncio.iscoroutinefunction(handler): + is_async = asyncio.iscoroutinefunction(handler) + if is_async: response = await handler(request) else: - loop = asyncio.get_event_loop() - response = await loop.run_in_executor(None, handler, request) + response = await run_in_threadpool(handler, request) return response async def method_not_allowed(self, request: Request) -> Response: diff --git a/starlette/formparsers.py b/starlette/formparsers.py index d5bdc229..4d640743 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -5,6 +5,7 @@ import typing from enum import Enum from urllib.parse import unquote +from starlette.concurrency import run_in_threadpool from starlette.datastructures import Headers try: @@ -44,19 +45,19 @@ class UploadFile: self._file = tempfile.SpooledTemporaryFile() async def setup(self) -> None: - await self._loop.run_in_executor(None, self.create_tempfile) + await run_in_threadpool(self.create_tempfile) async def write(self, data: bytes) -> None: - await self._loop.run_in_executor(None, self._file.write, data) + await run_in_threadpool(self._file.write, data) async def read(self, size: int = None) -> bytes: - return await self._loop.run_in_executor(None, self._file.read, size) + return await run_in_threadpool(self._file.read, size) async def seek(self, offset: int) -> None: - await self._loop.run_in_executor(None, self._file.seek, offset) + await run_in_threadpool(self._file.seek, offset) async def close(self) -> None: - await self._loop.run_in_executor(None, self._file.close) + await run_in_threadpool(self._file.close) class FormParser: diff --git a/starlette/graphql.py b/starlette/graphql.py index bcab8ddd..df9aaf4d 100644 --- a/starlette/graphql.py +++ b/starlette/graphql.py @@ -1,9 +1,9 @@ -import asyncio import functools import json import typing from starlette import status +from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse, Response from starlette.types import ASGIInstance, Receive, Scope, Send @@ -95,11 +95,12 @@ class GraphQLApp: return_promise=True, ) else: - func = functools.partial( - self.schema.execute, variables=variables, operation_name=operation_name + return await run_in_threadpool( + self.schema.execute, + query, + variables=variables, + operation_name=operation_name, ) - loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, func, query) async def handle_graphiql(self, request: Request) -> Response: text = GRAPHIQL.replace("{{REQUEST_PATH}}", json.dumps(request.url.path)) diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 77eb8450..ff8d756f 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -3,6 +3,7 @@ import io import sys import typing +from starlette.concurrency import run_in_threadpool from starlette.types import ASGIApp, ASGIInstance, Message, Receive, Scope, Send @@ -80,7 +81,7 @@ class WSGIResponder: body += message.get("body", b"") more_body = message.get("more_body", False) environ = build_environ(self.scope, body) - wsgi = self.loop.run_in_executor(None, self.wsgi, environ, self.start_response) + wsgi = run_in_threadpool(self.wsgi, environ, self.start_response) sender = self.loop.create_task(self.sender(send)) await asyncio.wait_for(wsgi, None) self.send_queue.append(None) diff --git a/starlette/routing.py b/starlette/routing.py index 6d7a8418..66037166 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -5,6 +5,7 @@ import typing from concurrent.futures import ThreadPoolExecutor from enum import Enum +from starlette.concurrency import run_in_threadpool from starlette.datastructures import URL, URLPath from starlette.exceptions import HTTPException from starlette.requests import Request @@ -41,8 +42,7 @@ def request_response(func: typing.Callable) -> ASGIApp: if is_coroutine: response = await func(request) else: - loop = asyncio.get_event_loop() - response = await loop.run_in_executor(None, func, request) + response = await run_in_threadpool(func, request) await response(receive, send) return awaitable