]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Flesh out future-oriented client interfaces.
authorBen Darnell <ben@bendarnell.com>
Sun, 2 Sep 2012 17:50:57 +0000 (10:50 -0700)
committerBen Darnell <ben@bendarnell.com>
Sun, 2 Sep 2012 17:52:53 +0000 (10:52 -0700)
Make DummyFuture a more complete imitation of Future.

Add tests demonstrating various implementation styles.

Add a decorator that simplifies use of futures.

tornado/concurrent.py
tornado/test/concurrent_test.py [new file with mode: 0644]
tornado/test/runtests.py

index 9bd0bac37d08d0ed25f404c116e8b41313db94f5..80596844db3161ce00f28706bc6d7c2c2b22b330 100644 (file)
@@ -18,6 +18,7 @@ from __future__ import absolute_import, division, with_statement
 import functools
 import sys
 
+from tornado.stack_context import ExceptionStackContext
 from tornado.util import raise_exc_info
 
 try:
@@ -27,9 +28,11 @@ except ImportError:
 
 
 class DummyFuture(object):
-    def __init__(self, result, exc_info=None):
-        self._result = result
-        self._exc_info = exc_info
+    def __init__(self):
+        self._done = False
+        self._result = None
+        self._exception = None
+        self._callbacks = []
 
     def cancel(self):
         return False
@@ -38,32 +41,62 @@ class DummyFuture(object):
         return False
 
     def running(self):
-        return False
+        return not self._done
 
     def done(self):
-        return True
+        return self._done
 
     def result(self, timeout=None):
-        if self._exc_info:
-            raise_exc_info(self._exc_info)
+        self._check_done()
+        if self._exception:
+            raise self._exception
         return self._result
 
     def exception(self, timeout=None):
-        if self._exc_info:
-            return self._exc_info[1]
+        self._check_done()
+        if self._exception:
+            return self._exception
         else:
             return None
 
     def add_done_callback(self, fn):
-        fn(self)
+        if self._done:
+            fn(self)
+        else:
+            self._callbacks.append(fn)
+
+    def set_result(self, result):
+        self._result = result
+        self._set_done()
+
+    def set_exception(self, exception):
+        self._exception = exception
+        self._set_done()
 
+    def _check_done(self):
+        if not self._done:
+            raise Exception("DummyFuture does not support blocking for results")
+
+    def _set_done(self):
+        self._done = True
+        for cb in self._callbacks:
+            # TODO: error handling
+            cb(self)
+        self._callbacks = None
+
+if futures is None:
+    Future = DummyFuture
+else:
+    Future = futures.Future
 
 class DummyExecutor(object):
     def submit(self, fn, *args, **kwargs):
+        future = Future()
         try:
-            return DummyFuture(fn(*args, **kwargs))
-        except Exception:
-            return DummyFuture(result=None, exc_info=sys.exc_info())
+            future.set_result(fn(*args, **kwargs))
+        except Exception, e:
+            future.set_exception(e)
+        return future
 
 dummy_executor = DummyExecutor()
 
@@ -76,3 +109,19 @@ def run_on_executor(fn):
             self.io_loop.add_future(future, callback)
         return future
     return wrapper
+
+# TODO: this needs a better name
+def future_wrap(f):
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        future = Future()
+        if kwargs.get('callback') is not None:
+            future.add_done_callback(kwargs.pop('callback'))
+        kwargs['callback'] = future.set_result
+        def handle_error(typ, value, tb):
+            future.set_exception(value)
+            return True
+        with ExceptionStackContext(handle_error):
+            f(*args, **kwargs)
+        return future
+    return wrapper
diff --git a/tornado/test/concurrent_test.py b/tornado/test/concurrent_test.py
new file mode 100644 (file)
index 0000000..87f590e
--- /dev/null
@@ -0,0 +1,171 @@
+#!/usr/bin/env python
+#
+# Copyright 2012 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+from __future__ import absolute_import, division, with_statement
+
+import logging
+import re
+import socket
+
+from tornado.concurrent import Future, future_wrap
+from tornado.escape import utf8, to_unicode
+from tornado import gen
+from tornado.iostream import IOStream
+from tornado.netutil import TCPServer
+from tornado.testing import AsyncTestCase, LogTrapTestCase, get_unused_port
+from tornado.util import b
+
+class CapServer(TCPServer):
+    def handle_stream(self, stream, address):
+        logging.info("handle_stream")
+        self.stream = stream
+        self.stream.read_until(b("\n"), self.handle_read)
+
+    def handle_read(self, data):
+        logging.info("handle_read")
+        data = to_unicode(data)
+        if data == data.upper():
+            self.stream.write(b("error\talready capitalized\n"))
+        else:
+            # data already has \n
+            self.stream.write(utf8("ok\t%s" % data.upper()))
+        self.stream.close()
+
+
+class CapError(Exception):
+    pass
+
+
+class BaseCapClient(object):
+    def __init__(self, port, io_loop):
+        self.port = port
+        self.io_loop = io_loop
+
+    def process_response(self, data):
+        status, message = re.match('(.*)\t(.*)\n', to_unicode(data)).groups()
+        if status == 'ok':
+            return message
+        else:
+            raise CapError(message)
+
+
+class ManualCapClient(BaseCapClient):
+    def capitalize(self, request_data, callback=None):
+        logging.info("capitalize")
+        self.request_data = request_data
+        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
+        self.stream.connect(('127.0.0.1', self.port),
+                            callback=self.handle_connect)
+        self.future = Future()
+        if callback is not None:
+            self.future.add_done_callback(callback)
+        return self.future
+
+    def handle_connect(self):
+        logging.info("handle_connect")
+        self.stream.write(utf8(self.request_data + "\n"))
+        self.stream.read_until(b('\n'), callback=self.handle_read)
+
+    def handle_read(self, data):
+        logging.info("handle_read")
+        self.stream.close()
+        try:
+            self.future.set_result(self.process_response(data))
+        except CapError, e:
+            self.future.set_exception(e)
+
+
+class DecoratorCapClient(BaseCapClient):
+    @future_wrap
+    def capitalize(self, request_data, callback):
+        logging.info("capitalize")
+        self.request_data = request_data
+        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
+        self.stream.connect(('127.0.0.1', self.port),
+                            callback=self.handle_connect)
+        self.callback = callback
+
+    def handle_connect(self):
+        logging.info("handle_connect")
+        self.stream.write(utf8(self.request_data + "\n"))
+        self.stream.read_until(b('\n'), callback=self.handle_read)
+
+    def handle_read(self, data):
+        logging.info("handle_read")
+        self.stream.close()
+        self.callback(self.process_response(data))
+
+
+class GeneratorCapClient(BaseCapClient):
+    @future_wrap
+    @gen.engine
+    def capitalize(self, request_data, callback):
+        logging.info('capitalize')
+        stream = IOStream(socket.socket(), io_loop=self.io_loop)
+        logging.info('connecting')
+        yield gen.Task(stream.connect, ('127.0.0.1', self.port))
+        stream.write(utf8(request_data + '\n'))
+        logging.info('reading')
+        data = yield gen.Task(stream.read_until, b('\n'))
+        logging.info('returning')
+        stream.close()
+        callback(self.process_response(data))
+
+
+class ClientTestMixin(object):
+    def setUp(self):
+        super(ClientTestMixin, self).setUp()
+        self.server = CapServer(io_loop=self.io_loop)
+        port = get_unused_port()
+        self.server.listen(port, address='127.0.0.1')
+        self.client = self.client_class(io_loop=self.io_loop, port=port)
+
+    def tearDown(self):
+        self.server.stop()
+        super(ClientTestMixin, self).tearDown()
+
+    def test_callback(self):
+        self.client.capitalize("hello", callback=self.stop)
+        future = self.wait()
+        self.assertEqual(future.result(), "HELLO")
+
+    def test_callback_error(self):
+        self.client.capitalize("HELLO", callback=self.stop)
+        future = self.wait()
+        self.assertRaisesRegexp(CapError, "already capitalized", future.result)
+
+    def test_future(self):
+        future = self.client.capitalize("hello")
+        self.io_loop.add_future(future, self.stop)
+        self.wait()
+        self.assertEqual(future.result(), "HELLO")
+
+    def test_future_error(self):
+        future = self.client.capitalize("HELLO")
+        self.io_loop.add_future(future, self.stop)
+        self.wait()
+        self.assertRaisesRegexp(CapError, "already capitalized", future.result)
+
+
+class ManualClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
+    client_class = ManualCapClient
+
+
+class DecoratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
+    client_class = DecoratorCapClient
+
+
+class GeneratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
+    client_class = GeneratorCapClient
index 90d94400f992c3f8c1e97b266f0e2026e63257d4..174554989ee0c0e6b74096bb02452f643b55e860 100644 (file)
@@ -10,6 +10,7 @@ TEST_MODULES = [
     'tornado.iostream.doctests',
     'tornado.util.doctests',
     'tornado.test.auth_test',
+    'tornado.test.concurrent_test',
     'tornado.test.curl_httpclient_test',
     'tornado.test.escape_test',
     'tornado.test.gen_test',
@@ -65,6 +66,11 @@ if __name__ == '__main__':
     warnings.filterwarnings("ignore", category=DeprecationWarning)
     warnings.filterwarnings("error", category=DeprecationWarning,
                             module=r"tornado\..*")
+    # The unittest module is aggressive about deprecating redundant methods,
+    # leaving some without non-deprecated spellings that work on both
+    # 2.7 and 3.2
+    warnings.filterwarnings("ignore", category=DeprecationWarning,
+                            message="Please use assert.* instead")
 
     import tornado.testing
     kwargs = {}