From: Antoine Pitrou Date: Thu, 16 Nov 2017 17:28:20 +0000 (+0100) Subject: Issue #2188: check whether a future was cancelled before calling set_result() X-Git-Tag: v5.0.0~38^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F2200%2Fhead;p=thirdparty%2Ftornado.git Issue #2188: check whether a future was cancelled before calling set_result() --- diff --git a/tornado/auth.py b/tornado/auth.py index f6d505a20..caae0ef8d 100644 --- a/tornado/auth.py +++ b/tornado/auth.py @@ -75,7 +75,9 @@ import hmac import time import uuid -from tornado.concurrent import Future, return_future, chain_future, future_set_exc_info +from tornado.concurrent import (Future, return_future, chain_future, + future_set_exc_info, + future_set_result_unless_cancelled) from tornado import gen from tornado import httpclient from tornado import escape @@ -295,7 +297,7 @@ class OpenIdMixin(object): claimed_id = self.get_argument("openid.claimed_id", None) if claimed_id: user["claimed_id"] = claimed_id - future.set_result(user) + future_set_result_unless_cancelled(future, user) def get_auth_http_client(self): """Returns the `.AsyncHTTPClient` instance to be used for auth requests. @@ -519,7 +521,7 @@ class OAuthMixin(object): future.set_exception(AuthError("Error getting user")) return user["access_token"] = access_token - future.set_result(user) + future_set_result_unless_cancelled(future, user) def _oauth_request_parameters(self, url, access_token, parameters={}, method="GET"): @@ -668,7 +670,7 @@ class OAuth2Mixin(object): (response.error, response.request.url))) return - future.set_result(escape.json_decode(response.body)) + future_set_result_unless_cancelled(future, escape.json_decode(response.body)) def get_auth_http_client(self): """Returns the `.AsyncHTTPClient` instance to be used for auth requests. @@ -811,7 +813,7 @@ class TwitterMixin(OAuthMixin): "Error response %s fetching %s" % (response.error, response.request.url))) return - future.set_result(escape.json_decode(response.body)) + future_set_result_unless_cancelled(future, escape.json_decode(response.body)) def _oauth_consumer_token(self): self.require_setting("twitter_consumer_key", "Twitter OAuth") @@ -915,7 +917,7 @@ class GoogleOAuth2Mixin(OAuth2Mixin): return args = escape.json_decode(response.body) - future.set_result(args) + future_set_result_unless_cancelled(future, args) class FacebookGraphMixin(OAuth2Mixin): @@ -1011,7 +1013,7 @@ class FacebookGraphMixin(OAuth2Mixin): def _on_get_user_info(self, future, session, fields, user): if user is None: - future.set_result(None) + future_set_result_unless_cancelled(future, None) return fieldmap = {} @@ -1024,7 +1026,7 @@ class FacebookGraphMixin(OAuth2Mixin): # This should change in Tornado 5.0. fieldmap.update({"access_token": session["access_token"], "session_expires": str(session.get("expires_in"))}) - future.set_result(fieldmap) + future_set_result_unless_cancelled(future, fieldmap) @_auth_return_future def facebook_request(self, path, callback, access_token=None, diff --git a/tornado/concurrent.py b/tornado/concurrent.py index 379e3eec6..bb1fb9480 100644 --- a/tornado/concurrent.py +++ b/tornado/concurrent.py @@ -377,7 +377,7 @@ class DummyExecutor(object): def submit(self, fn, *args, **kwargs): future = Future() try: - future.set_result(fn(*args, **kwargs)) + future_set_result_unless_cancelled(future, fn(*args, **kwargs)) except Exception: future_set_exc_info(future, sys.exc_info()) return future @@ -479,7 +479,7 @@ def return_future(f): def wrapper(*args, **kwargs): future = Future() callback, args, kwargs = replacer.replace( - lambda value=_NO_RESULT: future.set_result(value), + lambda value=_NO_RESULT: future_set_result_unless_cancelled(future, value), args, kwargs) def handle_error(typ, value, tb): @@ -547,6 +547,18 @@ def chain_future(a, b): IOLoop.current().add_future(a, copy) +def future_set_result_unless_cancelled(future, value): + """Set the given ``value`` as the `Future`'s result, if not cancelled. + + Avoids asyncio.InvalidStateError when calling set_result() on + a cancelled `asyncio.Future`. + + .. versionadded:: 5.0 + """ + if not future.cancelled(): + future.set_result(value) + + def future_set_exc_info(future, exc_info): """Set the given ``exc_info`` as the `Future`'s exception. diff --git a/tornado/gen.py b/tornado/gen.py index 7371ca517..ff23110a7 100644 --- a/tornado/gen.py +++ b/tornado/gen.py @@ -85,7 +85,8 @@ import textwrap import types import weakref -from tornado.concurrent import Future, is_future, chain_future, future_set_exc_info, future_add_done_callback +from tornado.concurrent import (Future, is_future, chain_future, future_set_exc_info, + future_add_done_callback, future_set_result_unless_cancelled) from tornado.ioloop import IOLoop from tornado.log import app_log from tornado import stack_context @@ -327,7 +328,7 @@ def _make_coroutine_wrapper(func, replace_callback): 'stack_context inconsistency (probably caused ' 'by yield within a "with StackContext" block)')) except (StopIteration, Return) as e: - future.set_result(_value_from_stopiteration(e)) + future_set_result_unless_cancelled(future, _value_from_stopiteration(e)) except Exception: future_set_exc_info(future, sys.exc_info()) else: @@ -345,7 +346,7 @@ def _make_coroutine_wrapper(func, replace_callback): # used in the absence of cycles). We can avoid the # cycle by clearing the local variable after we return it. future = None - future.set_result(result) + future_set_result_unless_cancelled(future, result) return future wrapper.__wrapped__ = wrapped @@ -631,7 +632,7 @@ def Task(func, *args, **kwargs): def set_result(result): if future.done(): return - future.set_result(result) + future_set_result_unless_cancelled(future, result) with stack_context.ExceptionStackContext(handle_exception): func(*args, callback=_argument_adapter(set_result), **kwargs) return future @@ -831,7 +832,8 @@ def multi_future(children, quiet_exceptions=()): future = _create_future() if not children: - future.set_result({} if keys is not None else []) + future_set_result_unless_cancelled(future, + {} if keys is not None else []) def callback(f): unfinished_children.remove(f) @@ -849,9 +851,10 @@ def multi_future(children, quiet_exceptions=()): future_set_exc_info(future, sys.exc_info()) if not future.done(): if keys is not None: - future.set_result(dict(zip(keys, result_list))) + future_set_result_unless_cancelled(future, + dict(zip(keys, result_list))) else: - future.set_result(result_list) + future_set_result_unless_cancelled(future, result_list) listening = set() for f in children: @@ -962,7 +965,8 @@ def sleep(duration): .. versionadded:: 4.1 """ f = _create_future() - IOLoop.current().call_later(duration, lambda: f.set_result(None)) + IOLoop.current().call_later(duration, + lambda: future_set_result_unless_cancelled(f, None)) return f @@ -1038,7 +1042,8 @@ class Runner(object): self.results[key] = result if self.yield_point is not None and self.yield_point.is_ready(): try: - self.future.set_result(self.yield_point.get_result()) + future_set_result_unless_cancelled(self.future, + self.yield_point.get_result()) except: future_set_exc_info(self.future, sys.exc_info()) self.yield_point = None @@ -1099,7 +1104,8 @@ class Runner(object): raise LeakedCallbackError( "finished without waiting for callbacks %r" % self.pending_callbacks) - self.result_future.set_result(_value_from_stopiteration(e)) + future_set_result_unless_cancelled(self.result_future, + _value_from_stopiteration(e)) self.result_future = None self._deactivate_stack_context() return @@ -1131,7 +1137,7 @@ class Runner(object): try: yielded.start(self) if yielded.is_ready(): - self.future.set_result( + future_set_result_unless_cancelled(self.future, yielded.get_result()) else: self.yield_point = yielded diff --git a/tornado/http1connection.py b/tornado/http1connection.py index 39d776c75..de8d8bf38 100644 --- a/tornado/http1connection.py +++ b/tornado/http1connection.py @@ -23,7 +23,8 @@ from __future__ import absolute_import, division, print_function import re -from tornado.concurrent import Future, future_add_done_callback +from tornado.concurrent import (Future, future_add_done_callback, + future_set_result_unless_cancelled) from tornado.escape import native_str, utf8 from tornado import gen from tornado import httputil @@ -291,7 +292,7 @@ class HTTP1Connection(httputil.HTTPConnection): self._close_callback = None callback() if not self._finish_future.done(): - self._finish_future.set_result(None) + future_set_result_unless_cancelled(self._finish_future, None) self._clear_callbacks() def close(self): @@ -299,7 +300,7 @@ class HTTP1Connection(httputil.HTTPConnection): self.stream.close() self._clear_callbacks() if not self._finish_future.done(): - self._finish_future.set_result(None) + future_set_result_unless_cancelled(self._finish_future, None) def detach(self): """Take control of the underlying stream. @@ -313,7 +314,7 @@ class HTTP1Connection(httputil.HTTPConnection): stream = self.stream self.stream = None if not self._finish_future.done(): - self._finish_future.set_result(None) + future_set_result_unless_cancelled(self._finish_future, None) return stream def set_body_timeout(self, timeout): @@ -483,7 +484,7 @@ class HTTP1Connection(httputil.HTTPConnection): if self._write_future is not None: future = self._write_future self._write_future = None - future.set_result(None) + future_set_result_unless_cancelled(future, None) def _can_keep_alive(self, start_line, headers): if self.params.no_keep_alive: @@ -510,7 +511,7 @@ class HTTP1Connection(httputil.HTTPConnection): # default state for the next request. self.stream.set_nodelay(False) if not self._finish_future.done(): - self._finish_future.set_result(None) + future_set_result_unless_cancelled(self._finish_future, None) def _parse_headers(self, data): # The lstrip removes newlines that some implementations sometimes diff --git a/tornado/httpclient.py b/tornado/httpclient.py index 1dea202e5..304e71097 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -44,7 +44,7 @@ import functools import time import weakref -from tornado.concurrent import Future +from tornado.concurrent import Future, future_set_result_unless_cancelled from tornado.escape import utf8, native_str from tornado import gen, httputil, stack_context from tornado.ioloop import IOLoop @@ -259,7 +259,7 @@ class AsyncHTTPClient(Configurable): if raise_error and response.error: future.set_exception(response.error) else: - future.set_result(response) + future_set_result_unless_cancelled(future, response) self.fetch_impl(request, handle_response) return future diff --git a/tornado/locks.py b/tornado/locks.py index 822c74421..fd0ad6af7 100644 --- a/tornado/locks.py +++ b/tornado/locks.py @@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function import collections from tornado import gen, ioloop -from tornado.concurrent import Future +from tornado.concurrent import Future, future_set_result_unless_cancelled __all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock'] @@ -129,7 +129,7 @@ class Condition(_TimeoutGarbageCollector): if timeout: def on_timeout(): if not waiter.done(): - waiter.set_result(False) + future_set_result_unless_cancelled(waiter, False) self._garbage_collect() io_loop = ioloop.IOLoop.current() timeout_handle = io_loop.add_timeout(timeout, on_timeout) @@ -147,7 +147,7 @@ class Condition(_TimeoutGarbageCollector): waiters.append(waiter) for waiter in waiters: - waiter.set_result(True) + future_set_result_unless_cancelled(waiter, True) def notify_all(self): """Wake all waiters.""" diff --git a/tornado/process.py b/tornado/process.py index 594913c6b..6da398497 100644 --- a/tornado/process.py +++ b/tornado/process.py @@ -29,7 +29,7 @@ import time from binascii import hexlify -from tornado.concurrent import Future +from tornado.concurrent import Future, future_set_result_unless_cancelled from tornado import ioloop from tornado.iostream import PipeIOStream from tornado.log import gen_log @@ -296,7 +296,7 @@ class Subprocess(object): # Unfortunately we don't have the original args any more. future.set_exception(CalledProcessError(ret, None)) else: - future.set_result(ret) + future_set_result_unless_cancelled(future, ret) self.set_exit_callback(callback) return future diff --git a/tornado/queues.py b/tornado/queues.py index 318c40933..cba0aa7f2 100644 --- a/tornado/queues.py +++ b/tornado/queues.py @@ -28,7 +28,7 @@ import collections import heapq from tornado import gen, ioloop -from tornado.concurrent import Future +from tornado.concurrent import Future, future_set_result_unless_cancelled from tornado.locks import Event __all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] @@ -195,7 +195,7 @@ class Queue(object): assert self.empty(), "queue non-empty, why are getters waiting?" getter = self._getters.popleft() self.__put_internal(item) - getter.set_result(self._get()) + future_set_result_unless_cancelled(getter, self._get()) elif self.full(): raise QueueFull else: @@ -231,7 +231,7 @@ class Queue(object): assert self.full(), "queue not full, why are putters waiting?" item, putter = self._putters.popleft() self.__put_internal(item) - putter.set_result(None) + future_set_result_unless_cancelled(putter, None) return self._get() elif self.qsize(): return self._get() diff --git a/tornado/test/concurrent_test.py b/tornado/test/concurrent_test.py index 2b8dc3655..955f410a3 100644 --- a/tornado/test/concurrent_test.py +++ b/tornado/test/concurrent_test.py @@ -22,7 +22,8 @@ import socket import sys import traceback -from tornado.concurrent import Future, return_future, ReturnValueIgnoredError, run_on_executor +from tornado.concurrent import (Future, return_future, ReturnValueIgnoredError, + run_on_executor, future_set_result_unless_cancelled) from tornado.escape import utf8, to_unicode from tornado import gen from tornado.ioloop import IOLoop @@ -40,6 +41,23 @@ except ImportError: futures = None +class MiscFutureTest(AsyncTestCase): + + def test_future_set_result_unless_cancelled(self): + fut = Future() + future_set_result_unless_cancelled(fut, 42) + self.assertEqual(fut.result(), 42) + self.assertFalse(fut.cancelled()) + + fut = Future() + fut.cancel() + is_cancelled = fut.cancelled() + future_set_result_unless_cancelled(fut, 42) + self.assertEqual(fut.cancelled(), is_cancelled) + if not is_cancelled: + self.assertEqual(fut.result(), 42) + + class ReturnFutureTest(AsyncTestCase): @return_future def sync_future(self, callback): diff --git a/tornado/web.py b/tornado/web.py index f2749f7c4..c7510a140 100644 --- a/tornado/web.py +++ b/tornado/web.py @@ -80,7 +80,7 @@ import types from inspect import isclass from io import BytesIO -from tornado.concurrent import Future +from tornado.concurrent import Future, future_set_result_unless_cancelled from tornado import escape from tornado import gen from tornado import httputil @@ -1512,7 +1512,7 @@ class RequestHandler(object): if self._prepared_future is not None: # Tell the Application we've finished with prepare() # and are ready for the body to arrive. - self._prepared_future.set_result(None) + future_set_result_unless_cancelled(self._prepared_future, None) if self._finished: return @@ -2109,7 +2109,7 @@ class _HandlerDelegate(httputil.HTTPMessageDelegate): def finish(self): if self.stream_request_body: - self.request.body.set_result(None) + future_set_result_unless_cancelled(self.request.body, None) else: self.request.body = b''.join(self.chunks) self.request._parse_body() diff --git a/tornado/websocket.py b/tornado/websocket.py index 1cc750d9b..d24a05f16 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -28,7 +28,7 @@ import tornado.escape import tornado.web import zlib -from tornado.concurrent import Future +from tornado.concurrent import Future, future_set_result_unless_cancelled from tornado.escape import utf8, native_str, to_unicode from tornado import gen, httpclient, httputil from tornado.ioloop import IOLoop, PeriodicCallback @@ -1140,7 +1140,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): # ability to see exceptions. self.final_callback = None - self.connect_future.set_result(self) + future_set_result_unless_cancelled(self.connect_future, self) def write_message(self, message, binary=False): """Sends a message to the WebSocket server.""" @@ -1160,7 +1160,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): assert self.read_future is None future = Future() if self.read_queue: - future.set_result(self.read_queue.popleft()) + future_set_result_unless_cancelled(future, self.read_queue.popleft()) else: self.read_future = future if callback is not None: @@ -1171,7 +1171,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): if self._on_message_callback: self._on_message_callback(message) elif self.read_future is not None: - self.read_future.set_result(message) + future_set_result_unless_cancelled(self.read_future, message) self.read_future = None else: self.read_queue.append(message)