From: Ben Darnell Date: Sat, 8 Dec 2012 23:25:11 +0000 (-0500) Subject: add_callback now takes *args, **kwargs. X-Git-Tag: v3.0.0~201 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ea79e8a9385c14f2b20601b5a9673b7a605c164c;p=thirdparty%2Ftornado.git add_callback now takes *args, **kwargs. This reduces the need for functools.partial or lambda wrappers, and works better with stack_context in some cases since binding the arguments within IOLoop lets it see whether the function is already wrapped. --- diff --git a/tornado/ioloop.py b/tornado/ioloop.py index 1db5e5a93..8567ab802 100644 --- a/tornado/ioloop.py +++ b/tornado/ioloop.py @@ -320,7 +320,7 @@ class IOLoop(Configurable): """ raise NotImplementedError() - def add_callback(self, callback): + def add_callback(self, callback, *args, **kwargs): """Calls the given callback on the next I/O loop iteration. It is safe to call this method from any thread at any time, @@ -335,7 +335,7 @@ class IOLoop(Configurable): """ raise NotImplementedError() - def add_callback_from_signal(self, callback): + def add_callback_from_signal(self, callback, *args, **kwargs): """Calls the given callback on the next I/O loop iteration. Safe for use from a Python signal handler; should not be used @@ -609,12 +609,13 @@ class PollIOLoop(IOLoop): # collection pass whenever there are too many dead timeouts. timeout.callback = None - def add_callback(self, callback): + def add_callback(self, callback, *args, **kwargs): with self._callback_lock: if self._closing: raise RuntimeError("IOLoop is closing") list_empty = not self._callbacks - self._callbacks.append(stack_context.wrap(callback)) + self._callbacks.append(functools.partial( + stack_context.wrap(callback), *args, **kwargs)) if list_empty and thread.get_ident() != self._thread_ident: # If we're in the IOLoop's thread, we know it's not currently # polling. If we're not, and we added the first callback to an @@ -624,12 +625,12 @@ class PollIOLoop(IOLoop): # avoid it when we can. self._waker.wake() - def add_callback_from_signal(self, callback): + def add_callback_from_signal(self, callback, *args, **kwargs): with stack_context.NullContext(): if thread.get_ident() != self._thread_ident: # if the signal is handled on another thread, we can add # it normally (modulo the NullContext) - self.add_callback(callback) + self.add_callback(callback, *args, **kwargs) else: # If we're on the IOLoop's thread, we cannot use # the regular add_callback because it may deadlock on @@ -639,7 +640,8 @@ class PollIOLoop(IOLoop): # _callback_lock block in IOLoop.start, we may modify # either the old or new version of self._callbacks, # but either way will work. - self._callbacks.append(stack_context.wrap(callback)) + self._callbacks.append(functools.partial( + stack_context.wrap(callback), *args, **kwargs)) class _Timeout(object): diff --git a/tornado/platform/twisted.py b/tornado/platform/twisted.py index 4536837f3..8aec93775 100644 --- a/tornado/platform/twisted.py +++ b/tornado/platform/twisted.py @@ -458,8 +458,9 @@ class TwistedIOLoop(tornado.ioloop.IOLoop): def remove_timeout(self, timeout): timeout.cancel() - def add_callback(self, callback): - self.reactor.callFromThread(wrap(callback)) + def add_callback(self, callback, *args, **kwargs): + self.reactor.callFromThread(functools.partial(wrap(callback), + *args, **kwargs)) - def add_callback_from_signal(self, callback): - self.add_callback(callback) + def add_callback_from_signal(self, callback, *args, **kwargs): + self.add_callback(callback, *args, **kwargs) diff --git a/tornado/test/ioloop_test.py b/tornado/test/ioloop_test.py index fa403c0c7..1f8e5401a 100644 --- a/tornado/test/ioloop_test.py +++ b/tornado/test/ioloop_test.py @@ -2,12 +2,14 @@ from __future__ import absolute_import, division, with_statement +import contextlib import datetime +import functools import threading import time from tornado.ioloop import IOLoop -from tornado.stack_context import ExceptionStackContext +from tornado.stack_context import ExceptionStackContext, StackContext, wrap from tornado.testing import AsyncTestCase, bind_unused_port from tornado.test.util import unittest @@ -111,6 +113,63 @@ class TestIOLoop(AsyncTestCase): self.assertEqual("IOLoop is closing", str(e)) break + +class TestIOLoopAddCallback(AsyncTestCase): + def setUp(self): + super(TestIOLoopAddCallback, self).setUp() + self.active_contexts = [] + + def add_callback(self, callback, *args, **kwargs): + self.io_loop.add_callback(callback, *args, **kwargs) + + @contextlib.contextmanager + def context(self, name): + self.active_contexts.append(name) + yield + self.assertEqual(self.active_contexts.pop(), name) + + def test_pre_wrap(self): + # A pre-wrapped callback is run in the context in which it was + # wrapped, not when it was added to the IOLoop. + def f1(): + self.assertIn('c1', self.active_contexts) + self.assertNotIn('c2', self.active_contexts) + self.stop() + + with StackContext(functools.partial(self.context, 'c1')): + wrapped = wrap(f1) + + with StackContext(functools.partial(self.context, 'c2')): + self.add_callback(wrapped) + + self.wait() + + def test_pre_wrap_with_args(self): + # Same as test_pre_wrap, but the function takes arguments. + # Implementation note: The function must not be wrapped in a + # functools.partial until after it has been passed through + # stack_context.wrap + def f1(foo, bar): + self.assertIn('c1', self.active_contexts) + self.assertNotIn('c2', self.active_contexts) + self.stop((foo, bar)) + + with StackContext(functools.partial(self.context, 'c1')): + wrapped = wrap(f1) + + with StackContext(functools.partial(self.context, 'c2')): + self.add_callback(wrapped, 1, bar=2) + + result = self.wait() + self.assertEqual(result, (1, 2)) + + +class TestIOLoopAddCallbackFromSignal(TestIOLoopAddCallback): + # Repeat the add_callback tests using add_callback_from_signal + def add_callback(self, callback, *args, **kwargs): + self.io_loop.add_callback_from_signal(callback, *args, **kwargs) + + class TestIOLoopFutures(AsyncTestCase): def test_add_future_threads(self): with futures.ThreadPoolExecutor(1) as pool: diff --git a/tornado/test/stack_context_test.py b/tornado/test/stack_context_test.py index 4bbba5213..53c640f8a 100644 --- a/tornado/test/stack_context_test.py +++ b/tornado/test/stack_context_test.py @@ -169,5 +169,6 @@ class StackContextTest(AsyncTestCase): self.io_loop.add_callback(f1) self.wait() + if __name__ == '__main__': unittest.main() diff --git a/website/sphinx/releases/next.rst b/website/sphinx/releases/next.rst index d784660e9..e996d79fb 100644 --- a/website/sphinx/releases/next.rst +++ b/website/sphinx/releases/next.rst @@ -194,3 +194,5 @@ In progress that are passed to the ``get``/``post``/etc method. These attributes are set before those methods are called, so they are available during ``prepare()`` +* `IOLoop.add_callback` and `add_callback_from_signal` now take + ``*args, **kwargs`` to pass along to the callback.