]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add streaming_callback to IOStream.read_bytes and read_until_close.
authorBen Darnell <ben@bendarnell.com>
Fri, 9 Sep 2011 05:07:32 +0000 (22:07 -0700)
committerBen Darnell <ben@bendarnell.com>
Fri, 9 Sep 2011 05:07:32 +0000 (22:07 -0700)
Closes #300.

tornado/iostream.py
tornado/test/iostream_test.py

index 1a4f4992f266a5b379b2f8bf4652db2aa3438169..5869ec8c06453dfe7c197ff28917dfa3e23508b3 100644 (file)
@@ -37,11 +37,9 @@ except ImportError:
 class IOStream(object):
     r"""A utility class to write to and read from a non-blocking socket.
 
-    We support three methods: write(), read_until(), and read_bytes().
+    We support a non-blocking ``write()`` and a family of ``read_*()`` methods.
     All of the methods take callbacks (since writing and reading are
-    non-blocking and asynchronous). read_until() reads the socket until
-    a given delimiter, and read_bytes() reads until a specified number
-    of bytes have been read from the socket.
+    non-blocking and asynchronous). 
 
     The socket parameter may either be connected or unconnected.  For
     server operations the socket is the result of calling socket.accept().
@@ -94,6 +92,7 @@ class IOStream(object):
         self._read_bytes = None
         self._read_until_close = False
         self._read_callback = None
+        self._streaming_callback = None
         self._write_callback = None
         self._close_callback = None
         self._connect_callback = None
@@ -154,12 +153,18 @@ class IOStream(object):
                 break
         self._add_io_state(self.io_loop.READ)
 
-    def read_bytes(self, num_bytes, callback):
-        """Call callback when we read the given number of bytes."""
+    def read_bytes(self, num_bytes, callback, streaming_callback=None):
+        """Call callback when we read the given number of bytes.
+
+        If a ``streaming_callback`` is given, it will be called with chunks
+        of data as they become available, and the argument to the final
+        ``callback`` will be empty.
+        """
         assert not self._read_callback, "Already reading"
         assert isinstance(num_bytes, int)
         self._read_bytes = num_bytes
         self._read_callback = stack_context.wrap(callback)
+        self._streaming_callback = stack_context.wrap(streaming_callback)
         while True:
             if self._read_from_buffer():
                 return
@@ -168,10 +173,15 @@ class IOStream(object):
                 break
         self._add_io_state(self.io_loop.READ)
 
-    def read_until_close(self, callback):
+    def read_until_close(self, callback, streaming_callback=None):
         """Reads all data from the socket until it is closed.
 
-        Subject to ``max_buffer_size`` limit from `IOStream` constructor.
+        If a ``streaming_callback`` is given, it will be called with chunks
+        of data as they become available, and the argument to the final
+        ``callback`` will be empty.
+
+        Subject to ``max_buffer_size`` limit from `IOStream` constructor if
+        a ``streaming_callback`` is not used.
         """
         assert not self._read_callback, "Already reading"
         if self.closed():
@@ -179,6 +189,7 @@ class IOStream(object):
             return
         self._read_until_close = True
         self._read_callback = stack_context.wrap(callback)
+        self._streaming_callback = stack_context.wrap(streaming_callback)
         self._add_io_state(self.io_loop.READ)
 
     def write(self, data, callback=None):
@@ -372,10 +383,16 @@ class IOStream(object):
         Returns True if the read was completed.
         """
         if self._read_bytes is not None:
+            if self._streaming_callback is not None and self._read_buffer_size:
+                bytes_to_consume = min(self._read_bytes, self._read_buffer_size)
+                self._read_bytes -= bytes_to_consume
+                self._run_callback(self._streaming_callback,
+                                   self._consume(bytes_to_consume))
             if self._read_buffer_size >= self._read_bytes:
                 num_bytes = self._read_bytes
                 callback = self._read_callback
                 self._read_callback = None
+                self._streaming_callback = None
                 self._read_bytes = None
                 self._run_callback(callback, self._consume(num_bytes))
                 return True
@@ -386,6 +403,7 @@ class IOStream(object):
                 callback = self._read_callback
                 delimiter_len = len(self._read_delimiter)
                 self._read_callback = None
+                self._streaming_callback = None
                 self._read_delimiter = None
                 self._run_callback(callback,
                                    self._consume(loc + delimiter_len))
@@ -396,9 +414,14 @@ class IOStream(object):
             if m:
                 callback = self._read_callback
                 self._read_callback = None
+                self._streaming_callback = None
                 self._read_regex = None
                 self._run_callback(callback, self._consume(m.end()))
                 return True
+        elif self._read_until_close:
+            if self._streaming_callback is not None and self._read_buffer_size:
+                self._run_callback(self._streaming_callback,
+                                   self._consume(self._read_buffer_size))
         return False
 
     def _handle_connect(self):
index 72df5f0a0794c57f745215b3f2c9bdfd06111d7a..9ba5575664acd5b0b9e414916008ec625dd4e524 100644 (file)
@@ -1,3 +1,4 @@
+from tornado import netutil
 from tornado.iostream import IOStream
 from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
 from tornado.util import b
@@ -12,6 +13,25 @@ class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
     def get_app(self):
         return Application([('/', HelloHandler)])
 
+    def make_iostream_pair(self):
+        port = get_unused_port()
+        [listener] = netutil.bind_sockets(port, '127.0.0.1',
+                                          family=socket.AF_INET)
+        streams = [None, None]
+        def accept_callback(connection, address):
+            streams[0] = IOStream(connection, io_loop=self.io_loop)
+            self.stop()
+        def connect_callback():
+            streams[1] = client_stream
+            self.stop()
+        netutil.add_accept_handler(listener, accept_callback,
+                                   io_loop=self.io_loop)
+        client_stream = IOStream(socket.socket(), io_loop=self.io_loop)
+        client_stream.connect(('127.0.0.1', port),
+                              callback=connect_callback)
+        self.wait(condition=lambda: all(streams))
+        return streams
+
     def test_read_zero_bytes(self):
         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
         s.connect(("localhost", self.get_http_port()))
@@ -67,3 +87,51 @@ class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
         data = self.wait()
         self.assertTrue(data.startswith(b("HTTP/1.0 200")))
         self.assertTrue(data.endswith(b("Hello")))
+
+    def test_streaming_callback(self):
+        server, client = self.make_iostream_pair()
+        try:
+            chunks = []
+            final_called = []
+            def streaming_callback(data):
+                chunks.append(data)
+                self.stop()
+            def final_callback(data):
+                assert not data
+                final_called.append(True)
+                self.stop()
+            server.read_bytes(6, callback=final_callback,
+                              streaming_callback=streaming_callback)
+            client.write(b("1234"))
+            self.wait(condition=lambda: chunks)
+            client.write(b("5678"))
+            self.wait(condition=lambda: final_called)
+            self.assertEqual(chunks, [b("1234"), b("56")])
+
+            # the rest of the last chunk is still in the buffer
+            server.read_bytes(2, callback=self.stop)
+            data = self.wait()
+            self.assertEqual(data, b("78"))
+        finally:
+            server.close()
+            client.close()
+
+    def test_streaming_until_close(self):
+        server, client = self.make_iostream_pair()
+        try:
+            chunks = []
+            def callback(data):
+                chunks.append(data)
+                self.stop()
+            client.read_until_close(callback=callback,
+                                    streaming_callback=callback)
+            server.write(b("1234"))
+            self.wait()
+            server.write(b("5678"))
+            self.wait()
+            server.close()
+            self.wait()
+            self.assertEqual(chunks, [b("1234"), b("5678"), b("")])
+        finally:
+            server.close()
+            client.close()