]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
add_callback now takes *args, **kwargs.
authorBen Darnell <ben@bendarnell.com>
Sat, 8 Dec 2012 23:25:11 +0000 (18:25 -0500)
committerBen Darnell <ben@bendarnell.com>
Sat, 8 Dec 2012 23:25:11 +0000 (18:25 -0500)
This reduces the need for functools.partial or lambda wrappers, and
works better with stack_context in some cases since binding the
arguments within IOLoop lets it see whether the function is already
wrapped.

tornado/ioloop.py
tornado/platform/twisted.py
tornado/test/ioloop_test.py
tornado/test/stack_context_test.py
website/sphinx/releases/next.rst

index 1db5e5a93d8d8f6c26c673f78df840bcc32996ca..8567ab8020173f8ed1bc9ac73858fb7dd74e1821 100644 (file)
@@ -320,7 +320,7 @@ class IOLoop(Configurable):
         """
         raise NotImplementedError()
 
-    def add_callback(self, callback):
+    def add_callback(self, callback, *args, **kwargs):
         """Calls the given callback on the next I/O loop iteration.
 
         It is safe to call this method from any thread at any time,
@@ -335,7 +335,7 @@ class IOLoop(Configurable):
         """
         raise NotImplementedError()
 
-    def add_callback_from_signal(self, callback):
+    def add_callback_from_signal(self, callback, *args, **kwargs):
         """Calls the given callback on the next I/O loop iteration.
 
         Safe for use from a Python signal handler; should not be used
@@ -609,12 +609,13 @@ class PollIOLoop(IOLoop):
         # collection pass whenever there are too many dead timeouts.
         timeout.callback = None
 
-    def add_callback(self, callback):
+    def add_callback(self, callback, *args, **kwargs):
         with self._callback_lock:
             if self._closing:
                 raise RuntimeError("IOLoop is closing")
             list_empty = not self._callbacks
-            self._callbacks.append(stack_context.wrap(callback))
+            self._callbacks.append(functools.partial(
+                    stack_context.wrap(callback), *args, **kwargs))
         if list_empty and thread.get_ident() != self._thread_ident:
             # If we're in the IOLoop's thread, we know it's not currently
             # polling.  If we're not, and we added the first callback to an
@@ -624,12 +625,12 @@ class PollIOLoop(IOLoop):
             # avoid it when we can.
             self._waker.wake()
 
-    def add_callback_from_signal(self, callback):
+    def add_callback_from_signal(self, callback, *args, **kwargs):
         with stack_context.NullContext():
             if thread.get_ident() != self._thread_ident:
                 # if the signal is handled on another thread, we can add
                 # it normally (modulo the NullContext)
-                self.add_callback(callback)
+                self.add_callback(callback, *args, **kwargs)
             else:
                 # If we're on the IOLoop's thread, we cannot use
                 # the regular add_callback because it may deadlock on
@@ -639,7 +640,8 @@ class PollIOLoop(IOLoop):
                 # _callback_lock block in IOLoop.start, we may modify
                 # either the old or new version of self._callbacks,
                 # but either way will work.
-                self._callbacks.append(stack_context.wrap(callback))
+                self._callbacks.append(functools.partial(
+                        stack_context.wrap(callback), *args, **kwargs))
 
 
 class _Timeout(object):
index 4536837f30c6da62b7c2cddfcb415db727d37d0d..8aec9377585f62f030527f97062a996a861b6974 100644 (file)
@@ -458,8 +458,9 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
     def remove_timeout(self, timeout):
         timeout.cancel()
 
-    def add_callback(self, callback):
-        self.reactor.callFromThread(wrap(callback))
+    def add_callback(self, callback, *args, **kwargs):
+        self.reactor.callFromThread(functools.partial(wrap(callback),
+                                                      *args, **kwargs))
 
-    def add_callback_from_signal(self, callback):
-        self.add_callback(callback)
+    def add_callback_from_signal(self, callback, *args, **kwargs):
+        self.add_callback(callback, *args, **kwargs)
index fa403c0c769b0dc97d8ddd51bbd21ac56d7a6cf9..1f8e5401a44fc845f2dc5fa2f11d6de997458043 100644 (file)
@@ -2,12 +2,14 @@
 
 
 from __future__ import absolute_import, division, with_statement
+import contextlib
 import datetime
+import functools
 import threading
 import time
 
 from tornado.ioloop import IOLoop
-from tornado.stack_context import ExceptionStackContext
+from tornado.stack_context import ExceptionStackContext, StackContext, wrap
 from tornado.testing import AsyncTestCase, bind_unused_port
 from tornado.test.util import unittest
 
@@ -111,6 +113,63 @@ class TestIOLoop(AsyncTestCase):
                 self.assertEqual("IOLoop is closing", str(e))
                 break
 
+
+class TestIOLoopAddCallback(AsyncTestCase):
+    def setUp(self):
+        super(TestIOLoopAddCallback, self).setUp()
+        self.active_contexts = []
+
+    def add_callback(self, callback, *args, **kwargs):
+        self.io_loop.add_callback(callback, *args, **kwargs)
+
+    @contextlib.contextmanager
+    def context(self, name):
+        self.active_contexts.append(name)
+        yield
+        self.assertEqual(self.active_contexts.pop(), name)
+
+    def test_pre_wrap(self):
+        # A pre-wrapped callback is run in the context in which it was
+        # wrapped, not when it was added to the IOLoop.
+        def f1():
+            self.assertIn('c1', self.active_contexts)
+            self.assertNotIn('c2', self.active_contexts)
+            self.stop()
+
+        with StackContext(functools.partial(self.context, 'c1')):
+            wrapped = wrap(f1)
+
+        with StackContext(functools.partial(self.context, 'c2')):
+            self.add_callback(wrapped)
+
+        self.wait()
+
+    def test_pre_wrap_with_args(self):
+        # Same as test_pre_wrap, but the function takes arguments.
+        # Implementation note: The function must not be wrapped in a
+        # functools.partial until after it has been passed through
+        # stack_context.wrap
+        def f1(foo, bar):
+            self.assertIn('c1', self.active_contexts)
+            self.assertNotIn('c2', self.active_contexts)
+            self.stop((foo, bar))
+
+        with StackContext(functools.partial(self.context, 'c1')):
+            wrapped = wrap(f1)
+
+        with StackContext(functools.partial(self.context, 'c2')):
+            self.add_callback(wrapped, 1, bar=2)
+
+        result = self.wait()
+        self.assertEqual(result, (1, 2))
+
+
+class TestIOLoopAddCallbackFromSignal(TestIOLoopAddCallback):
+    # Repeat the add_callback tests using add_callback_from_signal
+    def add_callback(self, callback, *args, **kwargs):
+        self.io_loop.add_callback_from_signal(callback, *args, **kwargs)
+
+
 class TestIOLoopFutures(AsyncTestCase):
     def test_add_future_threads(self):
         with futures.ThreadPoolExecutor(1) as pool:
index 4bbba5213051560e1d483c3d9fd6b7f0b22e0a19..53c640f8ada71e08ec9feb657f7c0b1f95dd461a 100644 (file)
@@ -169,5 +169,6 @@ class StackContextTest(AsyncTestCase):
         self.io_loop.add_callback(f1)
         self.wait()
 
+
 if __name__ == '__main__':
     unittest.main()
index d784660e9412cc13d3929916fa4c6906a11ebbd7..e996d79fbe53b18ce3ce41d338e08f578bb5379d 100644 (file)
@@ -194,3 +194,5 @@ In progress
   that are passed to the ``get``/``post``/etc method.  These attributes
   are set before those methods are called, so they are available during
   ``prepare()``
+* `IOLoop.add_callback` and `add_callback_from_signal` now take
+  ``*args, **kwargs`` to pass along to the callback.