]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Closes #21886, #21447: Fix a race condition in asyncio when setting the result
authorVictor Stinner <victor.stinner@gmail.com>
Sat, 5 Jul 2014 13:29:41 +0000 (15:29 +0200)
committerVictor Stinner <victor.stinner@gmail.com>
Sat, 5 Jul 2014 13:29:41 +0000 (15:29 +0200)
of a Future with call_soon(). Add an helper, a private method, to set the
result only if the future was not cancelled.

Lib/asyncio/coroutines.py
Lib/asyncio/futures.py
Lib/asyncio/proactor_events.py
Lib/asyncio/queues.py
Lib/asyncio/selector_events.py
Lib/asyncio/tasks.py
Lib/asyncio/unix_events.py
Lib/test/test_asyncio/test_futures.py
Lib/test/test_asyncio/test_tasks.py

index 71a1ec4dd0ea975877e2048dce69298cf6c1bca7..7654a0b9e0529bef4d687068f464005aaef49c3c 100644 (file)
@@ -64,6 +64,12 @@ class CoroWrapper:
         self.gen = gen
         self.func = func
         self._source_traceback = traceback.extract_stack(sys._getframe(1))
+        # __name__, __qualname__, __doc__ attributes are set by the coroutine()
+        # decorator
+
+    def __repr__(self):
+        return ('<%s %s>'
+                % (self.__class__.__name__, _format_coroutine(self)))
 
     def __iter__(self):
         return self
index fcc90d13718c3c2b030b37281c503a5d8aff7154..022fef76efe5fd75243f2e389c0a29fbbc41f90e 100644 (file)
@@ -316,6 +316,12 @@ class Future:
 
     # So-called internal methods (note: no set_running_or_notify_cancel()).
 
+    def _set_result_unless_cancelled(self, result):
+        """Helper setting the result only if the future was not cancelled."""
+        if self.cancelled():
+            return
+        self.set_result(result)
+
     def set_result(self, result):
         """Mark the future done and set its result.
 
index b76f69ee57107d0b07e5a63b53d6e4d1b0a1141d..a80876f366a40ab9b55f0600479880f85700144d 100644 (file)
@@ -38,7 +38,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin,
             self._server.attach(self)
         self._loop.call_soon(self._protocol.connection_made, self)
         if waiter is not None:
-            self._loop.call_soon(waiter.set_result, None)
+            self._loop.call_soon(waiter._set_result_unless_cancelled, None)
 
     def _set_extra(self, sock):
         self._extra['pipe'] = sock
index 57afb053ee2710c405afcb3f5e40e8bdf5bf2e8e..41551a9022faa2aa4fe92a8ab4a4c707993e86e1 100644 (file)
@@ -173,7 +173,7 @@ class Queue:
             # run, we need to defer the put for a tick to ensure that
             # getters and putters alternate perfectly. See
             # ChannelTest.test_wait.
-            self._loop.call_soon(putter.set_result, None)
+            self._loop.call_soon(putter._set_result_unless_cancelled, None)
 
             return self._get()
 
index df64aece3ba0f3c9ed0edf0c518cfdf1ad475482..2a170340b9e90ed8c6f87c26893eceb74ad5051e 100644 (file)
@@ -481,7 +481,7 @@ class _SelectorSocketTransport(_SelectorTransport):
         self._loop.add_reader(self._sock_fd, self._read_ready)
         self._loop.call_soon(self._protocol.connection_made, self)
         if waiter is not None:
-            self._loop.call_soon(waiter.set_result, None)
+            self._loop.call_soon(waiter._set_result_unless_cancelled, None)
 
     def pause_reading(self):
         if self._closing:
@@ -690,7 +690,8 @@ class _SelectorSslTransport(_SelectorTransport):
         self._loop.add_reader(self._sock_fd, self._read_ready)
         self._loop.call_soon(self._protocol.connection_made, self)
         if self._waiter is not None:
-            self._loop.call_soon(self._waiter.set_result, None)
+            self._loop.call_soon(self._waiter._set_result_unless_cancelled,
+                                 None)
 
     def pause_reading(self):
         # XXX This is a bit icky, given the comment at the top of
index dd191e770be07799b48a079ff8ba2dd85e0709dd..8c7217b702b43adcee66c6861cffe9b259cf67f6 100644 (file)
@@ -487,7 +487,8 @@ def as_completed(fs, *, loop=None, timeout=None):
 def sleep(delay, result=None, *, loop=None):
     """Coroutine that completes after a given time (in seconds)."""
     future = futures.Future(loop=loop)
-    h = future._loop.call_later(delay, future.set_result, result)
+    h = future._loop.call_later(delay,
+                                future._set_result_unless_cancelled, result)
     try:
         return (yield from future)
     finally:
index 5f728b5728a21a881fc3d79da4f8511b0ea04fdb..535ea2209bbb1c503e520132417583331aa1ef85 100644 (file)
@@ -269,7 +269,7 @@ class _UnixReadPipeTransport(transports.ReadTransport):
         self._loop.add_reader(self._fileno, self._read_ready)
         self._loop.call_soon(self._protocol.connection_made, self)
         if waiter is not None:
-            self._loop.call_soon(waiter.set_result, None)
+            self._loop.call_soon(waiter._set_result_unless_cancelled, None)
 
     def _read_ready(self):
         try:
@@ -353,7 +353,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
 
         self._loop.call_soon(self._protocol.connection_made, self)
         if waiter is not None:
-            self._loop.call_soon(waiter.set_result, None)
+            self._loop.call_soon(waiter._set_result_unless_cancelled, None)
 
     def get_write_buffer_size(self):
         return sum(len(data) for data in self._buffer)
index 96b41d69db9cb116b2755f568851026bd6c06b1c..a6071ea76ba8baf1518a21c2d02b2582742285bc 100644 (file)
@@ -343,6 +343,12 @@ class FutureTests(test_utils.TestCase):
         message = m_log.error.call_args[0][0]
         self.assertRegex(message, re.compile(regex, re.DOTALL))
 
+    def test_set_result_unless_cancelled(self):
+        fut = asyncio.Future(loop=self.loop)
+        fut.cancel()
+        fut._set_result_unless_cancelled(2)
+        self.assertTrue(fut.cancelled())
+
 
 class FutureDoneCallbackTests(test_utils.TestCase):
 
index 83b7e61fdb985371f502ce24d4a7ced9864f4edf..eaef05b50dd81e26ddc28e643ffa390d8326bf21 100644 (file)
@@ -211,6 +211,10 @@ class TaskTests(test_utils.TestCase):
         coro = ('%s() at %s:%s'
                 % (coro_qualname, code.co_filename, code.co_firstlineno))
 
+        # test repr(CoroWrapper)
+        if coroutines._DEBUG:
+            self.assertEqual(repr(gen), '<CoroWrapper %s>' % coro)
+
         # test pending Task
         t = asyncio.Task(gen, loop=self.loop)
         t.add_done_callback(Dummy())