From c9fb8e03125675563a2a5543d205d4f31fd6c6cb Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Thu, 8 Sep 2011 22:07:32 -0700 Subject: [PATCH] Add streaming_callback to IOStream.read_bytes and read_until_close. Closes #300. --- tornado/iostream.py | 39 +++++++++++++++----- tornado/test/iostream_test.py | 68 +++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 8 deletions(-) diff --git a/tornado/iostream.py b/tornado/iostream.py index 1a4f4992f..5869ec8c0 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -37,11 +37,9 @@ except ImportError: class IOStream(object): r"""A utility class to write to and read from a non-blocking socket. - We support three methods: write(), read_until(), and read_bytes(). + We support a non-blocking ``write()`` and a family of ``read_*()`` methods. All of the methods take callbacks (since writing and reading are - non-blocking and asynchronous). read_until() reads the socket until - a given delimiter, and read_bytes() reads until a specified number - of bytes have been read from the socket. + non-blocking and asynchronous). The socket parameter may either be connected or unconnected. For server operations the socket is the result of calling socket.accept(). @@ -94,6 +92,7 @@ class IOStream(object): self._read_bytes = None self._read_until_close = False self._read_callback = None + self._streaming_callback = None self._write_callback = None self._close_callback = None self._connect_callback = None @@ -154,12 +153,18 @@ class IOStream(object): break self._add_io_state(self.io_loop.READ) - def read_bytes(self, num_bytes, callback): - """Call callback when we read the given number of bytes.""" + def read_bytes(self, num_bytes, callback, streaming_callback=None): + """Call callback when we read the given number of bytes. + + If a ``streaming_callback`` is given, it will be called with chunks + of data as they become available, and the argument to the final + ``callback`` will be empty. + """ assert not self._read_callback, "Already reading" assert isinstance(num_bytes, int) self._read_bytes = num_bytes self._read_callback = stack_context.wrap(callback) + self._streaming_callback = stack_context.wrap(streaming_callback) while True: if self._read_from_buffer(): return @@ -168,10 +173,15 @@ class IOStream(object): break self._add_io_state(self.io_loop.READ) - def read_until_close(self, callback): + def read_until_close(self, callback, streaming_callback=None): """Reads all data from the socket until it is closed. - Subject to ``max_buffer_size`` limit from `IOStream` constructor. + If a ``streaming_callback`` is given, it will be called with chunks + of data as they become available, and the argument to the final + ``callback`` will be empty. + + Subject to ``max_buffer_size`` limit from `IOStream` constructor if + a ``streaming_callback`` is not used. """ assert not self._read_callback, "Already reading" if self.closed(): @@ -179,6 +189,7 @@ class IOStream(object): return self._read_until_close = True self._read_callback = stack_context.wrap(callback) + self._streaming_callback = stack_context.wrap(streaming_callback) self._add_io_state(self.io_loop.READ) def write(self, data, callback=None): @@ -372,10 +383,16 @@ class IOStream(object): Returns True if the read was completed. """ if self._read_bytes is not None: + if self._streaming_callback is not None and self._read_buffer_size: + bytes_to_consume = min(self._read_bytes, self._read_buffer_size) + self._read_bytes -= bytes_to_consume + self._run_callback(self._streaming_callback, + self._consume(bytes_to_consume)) if self._read_buffer_size >= self._read_bytes: num_bytes = self._read_bytes callback = self._read_callback self._read_callback = None + self._streaming_callback = None self._read_bytes = None self._run_callback(callback, self._consume(num_bytes)) return True @@ -386,6 +403,7 @@ class IOStream(object): callback = self._read_callback delimiter_len = len(self._read_delimiter) self._read_callback = None + self._streaming_callback = None self._read_delimiter = None self._run_callback(callback, self._consume(loc + delimiter_len)) @@ -396,9 +414,14 @@ class IOStream(object): if m: callback = self._read_callback self._read_callback = None + self._streaming_callback = None self._read_regex = None self._run_callback(callback, self._consume(m.end())) return True + elif self._read_until_close: + if self._streaming_callback is not None and self._read_buffer_size: + self._run_callback(self._streaming_callback, + self._consume(self._read_buffer_size)) return False def _handle_connect(self): diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py index 72df5f0a0..9ba557566 100644 --- a/tornado/test/iostream_test.py +++ b/tornado/test/iostream_test.py @@ -1,3 +1,4 @@ +from tornado import netutil from tornado.iostream import IOStream from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port from tornado.util import b @@ -12,6 +13,25 @@ class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase): def get_app(self): return Application([('/', HelloHandler)]) + def make_iostream_pair(self): + port = get_unused_port() + [listener] = netutil.bind_sockets(port, '127.0.0.1', + family=socket.AF_INET) + streams = [None, None] + def accept_callback(connection, address): + streams[0] = IOStream(connection, io_loop=self.io_loop) + self.stop() + def connect_callback(): + streams[1] = client_stream + self.stop() + netutil.add_accept_handler(listener, accept_callback, + io_loop=self.io_loop) + client_stream = IOStream(socket.socket(), io_loop=self.io_loop) + client_stream.connect(('127.0.0.1', port), + callback=connect_callback) + self.wait(condition=lambda: all(streams)) + return streams + def test_read_zero_bytes(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) s.connect(("localhost", self.get_http_port())) @@ -67,3 +87,51 @@ class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase): data = self.wait() self.assertTrue(data.startswith(b("HTTP/1.0 200"))) self.assertTrue(data.endswith(b("Hello"))) + + def test_streaming_callback(self): + server, client = self.make_iostream_pair() + try: + chunks = [] + final_called = [] + def streaming_callback(data): + chunks.append(data) + self.stop() + def final_callback(data): + assert not data + final_called.append(True) + self.stop() + server.read_bytes(6, callback=final_callback, + streaming_callback=streaming_callback) + client.write(b("1234")) + self.wait(condition=lambda: chunks) + client.write(b("5678")) + self.wait(condition=lambda: final_called) + self.assertEqual(chunks, [b("1234"), b("56")]) + + # the rest of the last chunk is still in the buffer + server.read_bytes(2, callback=self.stop) + data = self.wait() + self.assertEqual(data, b("78")) + finally: + server.close() + client.close() + + def test_streaming_until_close(self): + server, client = self.make_iostream_pair() + try: + chunks = [] + def callback(data): + chunks.append(data) + self.stop() + client.read_until_close(callback=callback, + streaming_callback=callback) + server.write(b("1234")) + self.wait() + server.write(b("5678")) + self.wait() + server.close() + self.wait() + self.assertEqual(chunks, [b("1234"), b("5678"), b("")]) + finally: + server.close() + client.close() -- 2.47.2