]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Issue #2188: check whether a future was cancelled before calling set_result() 2200/head
authorAntoine Pitrou <antoine@python.org>
Thu, 16 Nov 2017 17:28:20 +0000 (18:28 +0100)
committerAntoine Pitrou <antoine@python.org>
Thu, 16 Nov 2017 17:28:20 +0000 (18:28 +0100)
tornado/auth.py
tornado/concurrent.py
tornado/gen.py
tornado/http1connection.py
tornado/httpclient.py
tornado/locks.py
tornado/process.py
tornado/queues.py
tornado/test/concurrent_test.py
tornado/web.py
tornado/websocket.py

index f6d505a20711969d77ea8919dfa2703697699e57..caae0ef8d3c5323b78bfefbaf891031ea5194171 100644 (file)
@@ -75,7 +75,9 @@ import hmac
 import time
 import uuid
 
-from tornado.concurrent import Future, return_future, chain_future, future_set_exc_info
+from tornado.concurrent import (Future, return_future, chain_future,
+                                future_set_exc_info,
+                                future_set_result_unless_cancelled)
 from tornado import gen
 from tornado import httpclient
 from tornado import escape
@@ -295,7 +297,7 @@ class OpenIdMixin(object):
         claimed_id = self.get_argument("openid.claimed_id", None)
         if claimed_id:
             user["claimed_id"] = claimed_id
-        future.set_result(user)
+        future_set_result_unless_cancelled(future, user)
 
     def get_auth_http_client(self):
         """Returns the `.AsyncHTTPClient` instance to be used for auth requests.
@@ -519,7 +521,7 @@ class OAuthMixin(object):
             future.set_exception(AuthError("Error getting user"))
             return
         user["access_token"] = access_token
-        future.set_result(user)
+        future_set_result_unless_cancelled(future, user)
 
     def _oauth_request_parameters(self, url, access_token, parameters={},
                                   method="GET"):
@@ -668,7 +670,7 @@ class OAuth2Mixin(object):
                                            (response.error, response.request.url)))
             return
 
-        future.set_result(escape.json_decode(response.body))
+        future_set_result_unless_cancelled(future, escape.json_decode(response.body))
 
     def get_auth_http_client(self):
         """Returns the `.AsyncHTTPClient` instance to be used for auth requests.
@@ -811,7 +813,7 @@ class TwitterMixin(OAuthMixin):
                 "Error response %s fetching %s" % (response.error,
                                                    response.request.url)))
             return
-        future.set_result(escape.json_decode(response.body))
+        future_set_result_unless_cancelled(future, escape.json_decode(response.body))
 
     def _oauth_consumer_token(self):
         self.require_setting("twitter_consumer_key", "Twitter OAuth")
@@ -915,7 +917,7 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
             return
 
         args = escape.json_decode(response.body)
-        future.set_result(args)
+        future_set_result_unless_cancelled(future, args)
 
 
 class FacebookGraphMixin(OAuth2Mixin):
@@ -1011,7 +1013,7 @@ class FacebookGraphMixin(OAuth2Mixin):
 
     def _on_get_user_info(self, future, session, fields, user):
         if user is None:
-            future.set_result(None)
+            future_set_result_unless_cancelled(future, None)
             return
 
         fieldmap = {}
@@ -1024,7 +1026,7 @@ class FacebookGraphMixin(OAuth2Mixin):
         # This should change in Tornado 5.0.
         fieldmap.update({"access_token": session["access_token"],
                          "session_expires": str(session.get("expires_in"))})
-        future.set_result(fieldmap)
+        future_set_result_unless_cancelled(future, fieldmap)
 
     @_auth_return_future
     def facebook_request(self, path, callback, access_token=None,
index 379e3eec6f4d2e6c37bb2566c782fe54419fd8fe..bb1fb94803f2413c5087f34b645daab9b73c0da9 100644 (file)
@@ -377,7 +377,7 @@ class DummyExecutor(object):
     def submit(self, fn, *args, **kwargs):
         future = Future()
         try:
-            future.set_result(fn(*args, **kwargs))
+            future_set_result_unless_cancelled(future, fn(*args, **kwargs))
         except Exception:
             future_set_exc_info(future, sys.exc_info())
         return future
@@ -479,7 +479,7 @@ def return_future(f):
     def wrapper(*args, **kwargs):
         future = Future()
         callback, args, kwargs = replacer.replace(
-            lambda value=_NO_RESULT: future.set_result(value),
+            lambda value=_NO_RESULT: future_set_result_unless_cancelled(future, value),
             args, kwargs)
 
         def handle_error(typ, value, tb):
@@ -547,6 +547,18 @@ def chain_future(a, b):
         IOLoop.current().add_future(a, copy)
 
 
+def future_set_result_unless_cancelled(future, value):
+    """Set the given ``value`` as the `Future`'s result, if not cancelled.
+
+    Avoids asyncio.InvalidStateError when calling set_result() on
+    a cancelled `asyncio.Future`.
+
+    .. versionadded:: 5.0
+    """
+    if not future.cancelled():
+        future.set_result(value)
+
+
 def future_set_exc_info(future, exc_info):
     """Set the given ``exc_info`` as the `Future`'s exception.
 
index 7371ca517cb9d7635636b556134c44e87e2ad4ce..ff23110a799066f8349aeb54bc4f55d15e2782ed 100644 (file)
@@ -85,7 +85,8 @@ import textwrap
 import types
 import weakref
 
-from tornado.concurrent import Future, is_future, chain_future, future_set_exc_info, future_add_done_callback
+from tornado.concurrent import (Future, is_future, chain_future, future_set_exc_info,
+                                future_add_done_callback, future_set_result_unless_cancelled)
 from tornado.ioloop import IOLoop
 from tornado.log import app_log
 from tornado import stack_context
@@ -327,7 +328,7 @@ def _make_coroutine_wrapper(func, replace_callback):
                                 'stack_context inconsistency (probably caused '
                                 'by yield within a "with StackContext" block)'))
                 except (StopIteration, Return) as e:
-                    future.set_result(_value_from_stopiteration(e))
+                    future_set_result_unless_cancelled(future, _value_from_stopiteration(e))
                 except Exception:
                     future_set_exc_info(future, sys.exc_info())
                 else:
@@ -345,7 +346,7 @@ def _make_coroutine_wrapper(func, replace_callback):
                     # used in the absence of cycles).  We can avoid the
                     # cycle by clearing the local variable after we return it.
                     future = None
-        future.set_result(result)
+        future_set_result_unless_cancelled(future, result)
         return future
 
     wrapper.__wrapped__ = wrapped
@@ -631,7 +632,7 @@ def Task(func, *args, **kwargs):
     def set_result(result):
         if future.done():
             return
-        future.set_result(result)
+        future_set_result_unless_cancelled(future, result)
     with stack_context.ExceptionStackContext(handle_exception):
         func(*args, callback=_argument_adapter(set_result), **kwargs)
     return future
@@ -831,7 +832,8 @@ def multi_future(children, quiet_exceptions=()):
 
     future = _create_future()
     if not children:
-        future.set_result({} if keys is not None else [])
+        future_set_result_unless_cancelled(future,
+                                           {} if keys is not None else [])
 
     def callback(f):
         unfinished_children.remove(f)
@@ -849,9 +851,10 @@ def multi_future(children, quiet_exceptions=()):
                         future_set_exc_info(future, sys.exc_info())
             if not future.done():
                 if keys is not None:
-                    future.set_result(dict(zip(keys, result_list)))
+                    future_set_result_unless_cancelled(future,
+                                                       dict(zip(keys, result_list)))
                 else:
-                    future.set_result(result_list)
+                    future_set_result_unless_cancelled(future, result_list)
 
     listening = set()
     for f in children:
@@ -962,7 +965,8 @@ def sleep(duration):
     .. versionadded:: 4.1
     """
     f = _create_future()
-    IOLoop.current().call_later(duration, lambda: f.set_result(None))
+    IOLoop.current().call_later(duration,
+                                lambda: future_set_result_unless_cancelled(f, None))
     return f
 
 
@@ -1038,7 +1042,8 @@ class Runner(object):
         self.results[key] = result
         if self.yield_point is not None and self.yield_point.is_ready():
             try:
-                self.future.set_result(self.yield_point.get_result())
+                future_set_result_unless_cancelled(self.future,
+                                                   self.yield_point.get_result())
             except:
                 future_set_exc_info(self.future, sys.exc_info())
             self.yield_point = None
@@ -1099,7 +1104,8 @@ class Runner(object):
                         raise LeakedCallbackError(
                             "finished without waiting for callbacks %r" %
                             self.pending_callbacks)
-                    self.result_future.set_result(_value_from_stopiteration(e))
+                    future_set_result_unless_cancelled(self.result_future,
+                                                       _value_from_stopiteration(e))
                     self.result_future = None
                     self._deactivate_stack_context()
                     return
@@ -1131,7 +1137,7 @@ class Runner(object):
                 try:
                     yielded.start(self)
                     if yielded.is_ready():
-                        self.future.set_result(
+                        future_set_result_unless_cancelled(self.future,
                             yielded.get_result())
                     else:
                         self.yield_point = yielded
index 39d776c75e328dbc2361b7227cc9123a85d16c33..de8d8bf383ee6a440f3aa1306e85f19e19f34e5f 100644 (file)
@@ -23,7 +23,8 @@ from __future__ import absolute_import, division, print_function
 
 import re
 
-from tornado.concurrent import Future, future_add_done_callback
+from tornado.concurrent import (Future, future_add_done_callback,
+                                future_set_result_unless_cancelled)
 from tornado.escape import native_str, utf8
 from tornado import gen
 from tornado import httputil
@@ -291,7 +292,7 @@ class HTTP1Connection(httputil.HTTPConnection):
             self._close_callback = None
             callback()
         if not self._finish_future.done():
-            self._finish_future.set_result(None)
+            future_set_result_unless_cancelled(self._finish_future, None)
         self._clear_callbacks()
 
     def close(self):
@@ -299,7 +300,7 @@ class HTTP1Connection(httputil.HTTPConnection):
             self.stream.close()
         self._clear_callbacks()
         if not self._finish_future.done():
-            self._finish_future.set_result(None)
+            future_set_result_unless_cancelled(self._finish_future, None)
 
     def detach(self):
         """Take control of the underlying stream.
@@ -313,7 +314,7 @@ class HTTP1Connection(httputil.HTTPConnection):
         stream = self.stream
         self.stream = None
         if not self._finish_future.done():
-            self._finish_future.set_result(None)
+            future_set_result_unless_cancelled(self._finish_future, None)
         return stream
 
     def set_body_timeout(self, timeout):
@@ -483,7 +484,7 @@ class HTTP1Connection(httputil.HTTPConnection):
         if self._write_future is not None:
             future = self._write_future
             self._write_future = None
-            future.set_result(None)
+            future_set_result_unless_cancelled(future, None)
 
     def _can_keep_alive(self, start_line, headers):
         if self.params.no_keep_alive:
@@ -510,7 +511,7 @@ class HTTP1Connection(httputil.HTTPConnection):
         # default state for the next request.
         self.stream.set_nodelay(False)
         if not self._finish_future.done():
-            self._finish_future.set_result(None)
+            future_set_result_unless_cancelled(self._finish_future, None)
 
     def _parse_headers(self, data):
         # The lstrip removes newlines that some implementations sometimes
index 1dea202e558e3cc885073a1ad02ba95393a8b725..304e710976ab31179a87a96059c2c3c0a1942900 100644 (file)
@@ -44,7 +44,7 @@ import functools
 import time
 import weakref
 
-from tornado.concurrent import Future
+from tornado.concurrent import Future, future_set_result_unless_cancelled
 from tornado.escape import utf8, native_str
 from tornado import gen, httputil, stack_context
 from tornado.ioloop import IOLoop
@@ -259,7 +259,7 @@ class AsyncHTTPClient(Configurable):
             if raise_error and response.error:
                 future.set_exception(response.error)
             else:
-                future.set_result(response)
+                future_set_result_unless_cancelled(future, response)
         self.fetch_impl(request, handle_response)
         return future
 
index 822c74421e4aff2b41327cfe558e88954e64d147..fd0ad6af77db84db391523cf89e71cb977e550f5 100644 (file)
@@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function
 import collections
 
 from tornado import gen, ioloop
-from tornado.concurrent import Future
+from tornado.concurrent import Future, future_set_result_unless_cancelled
 
 __all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
 
@@ -129,7 +129,7 @@ class Condition(_TimeoutGarbageCollector):
         if timeout:
             def on_timeout():
                 if not waiter.done():
-                    waiter.set_result(False)
+                    future_set_result_unless_cancelled(waiter, False)
                 self._garbage_collect()
             io_loop = ioloop.IOLoop.current()
             timeout_handle = io_loop.add_timeout(timeout, on_timeout)
@@ -147,7 +147,7 @@ class Condition(_TimeoutGarbageCollector):
                 waiters.append(waiter)
 
         for waiter in waiters:
-            waiter.set_result(True)
+            future_set_result_unless_cancelled(waiter, True)
 
     def notify_all(self):
         """Wake all waiters."""
index 594913c6b4369f31585f6a7f13ed1f85e79ebf9f..6da398497d9156af9126308a621001dd3719e3b2 100644 (file)
@@ -29,7 +29,7 @@ import time
 
 from binascii import hexlify
 
-from tornado.concurrent import Future
+from tornado.concurrent import Future, future_set_result_unless_cancelled
 from tornado import ioloop
 from tornado.iostream import PipeIOStream
 from tornado.log import gen_log
@@ -296,7 +296,7 @@ class Subprocess(object):
                 # Unfortunately we don't have the original args any more.
                 future.set_exception(CalledProcessError(ret, None))
             else:
-                future.set_result(ret)
+                future_set_result_unless_cancelled(future, ret)
         self.set_exit_callback(callback)
         return future
 
index 318c4093355fe17feda4a39298f9569a692b4a0c..cba0aa7f29aafd706aa327fc194ac370d5b5fbc9 100644 (file)
@@ -28,7 +28,7 @@ import collections
 import heapq
 
 from tornado import gen, ioloop
-from tornado.concurrent import Future
+from tornado.concurrent import Future, future_set_result_unless_cancelled
 from tornado.locks import Event
 
 __all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
@@ -195,7 +195,7 @@ class Queue(object):
             assert self.empty(), "queue non-empty, why are getters waiting?"
             getter = self._getters.popleft()
             self.__put_internal(item)
-            getter.set_result(self._get())
+            future_set_result_unless_cancelled(getter, self._get())
         elif self.full():
             raise QueueFull
         else:
@@ -231,7 +231,7 @@ class Queue(object):
             assert self.full(), "queue not full, why are putters waiting?"
             item, putter = self._putters.popleft()
             self.__put_internal(item)
-            putter.set_result(None)
+            future_set_result_unless_cancelled(putter, None)
             return self._get()
         elif self.qsize():
             return self._get()
index 2b8dc3655022fddd57f4a0f1dbd86659005e0ded..955f410a3879e3bbdd37afeea832f5ba399acf19 100644 (file)
@@ -22,7 +22,8 @@ import socket
 import sys
 import traceback
 
-from tornado.concurrent import Future, return_future, ReturnValueIgnoredError, run_on_executor
+from tornado.concurrent import (Future, return_future, ReturnValueIgnoredError,
+                                run_on_executor, future_set_result_unless_cancelled)
 from tornado.escape import utf8, to_unicode
 from tornado import gen
 from tornado.ioloop import IOLoop
@@ -40,6 +41,23 @@ except ImportError:
     futures = None
 
 
+class MiscFutureTest(AsyncTestCase):
+
+    def test_future_set_result_unless_cancelled(self):
+        fut = Future()
+        future_set_result_unless_cancelled(fut, 42)
+        self.assertEqual(fut.result(), 42)
+        self.assertFalse(fut.cancelled())
+
+        fut = Future()
+        fut.cancel()
+        is_cancelled = fut.cancelled()
+        future_set_result_unless_cancelled(fut, 42)
+        self.assertEqual(fut.cancelled(), is_cancelled)
+        if not is_cancelled:
+            self.assertEqual(fut.result(), 42)
+
+
 class ReturnFutureTest(AsyncTestCase):
     @return_future
     def sync_future(self, callback):
index f2749f7c416a6f14fc6c55dd3968d20c7e438adc..c7510a140fec7eef477785a033f2e73d0d3a7d0d 100644 (file)
@@ -80,7 +80,7 @@ import types
 from inspect import isclass
 from io import BytesIO
 
-from tornado.concurrent import Future
+from tornado.concurrent import Future, future_set_result_unless_cancelled
 from tornado import escape
 from tornado import gen
 from tornado import httputil
@@ -1512,7 +1512,7 @@ class RequestHandler(object):
             if self._prepared_future is not None:
                 # Tell the Application we've finished with prepare()
                 # and are ready for the body to arrive.
-                self._prepared_future.set_result(None)
+                future_set_result_unless_cancelled(self._prepared_future, None)
             if self._finished:
                 return
 
@@ -2109,7 +2109,7 @@ class _HandlerDelegate(httputil.HTTPMessageDelegate):
 
     def finish(self):
         if self.stream_request_body:
-            self.request.body.set_result(None)
+            future_set_result_unless_cancelled(self.request.body, None)
         else:
             self.request.body = b''.join(self.chunks)
             self.request._parse_body()
index 1cc750d9b21b4005e7f1dc786e0cb57f349130bb..d24a05f16bb94209652d1bf9ee1c654763fdaecd 100644 (file)
@@ -28,7 +28,7 @@ import tornado.escape
 import tornado.web
 import zlib
 
-from tornado.concurrent import Future
+from tornado.concurrent import Future, future_set_result_unless_cancelled
 from tornado.escape import utf8, native_str, to_unicode
 from tornado import gen, httpclient, httputil
 from tornado.ioloop import IOLoop, PeriodicCallback
@@ -1140,7 +1140,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         # ability to see exceptions.
         self.final_callback = None
 
-        self.connect_future.set_result(self)
+        future_set_result_unless_cancelled(self.connect_future, self)
 
     def write_message(self, message, binary=False):
         """Sends a message to the WebSocket server."""
@@ -1160,7 +1160,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         assert self.read_future is None
         future = Future()
         if self.read_queue:
-            future.set_result(self.read_queue.popleft())
+            future_set_result_unless_cancelled(future, self.read_queue.popleft())
         else:
             self.read_future = future
         if callback is not None:
@@ -1171,7 +1171,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         if self._on_message_callback:
             self._on_message_callback(message)
         elif self.read_future is not None:
-            self.read_future.set_result(message)
+            future_set_result_unless_cancelled(self.read_future, message)
             self.read_future = None
         else:
             self.read_queue.append(message)