]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add a 'partial' flag to IOStream.read_bytes.
authorBen Darnell <ben@bendarnell.com>
Sat, 29 Mar 2014 15:45:52 +0000 (15:45 +0000)
committerBen Darnell <ben@bendarnell.com>
Sat, 29 Mar 2014 15:46:03 +0000 (15:46 +0000)
This is a more coroutine-friendly alternative to streaming_callback.

tornado/iostream.py
tornado/test/iostream_test.py

index 113ee094d8ce1dea0438dc17a2dd5e3858e9b9f7..98afea9c29aaac729cb24fd5c47999d7d4e0b427 100644 (file)
@@ -93,6 +93,7 @@ class BaseIOStream(object):
         self._read_delimiter = None
         self._read_regex = None
         self._read_bytes = None
+        self._read_partial = False
         self._read_until_close = False
         self._read_callback = None
         self._read_future = None
@@ -168,17 +169,22 @@ class BaseIOStream(object):
         self._try_inline_read()
         return future
 
-    def read_bytes(self, num_bytes, callback=None, streaming_callback=None):
+    def read_bytes(self, num_bytes, callback=None, streaming_callback=None,
+                   partial=False):
         """Run callback when we read the given number of bytes.
 
         If a ``streaming_callback`` is given, it will be called with chunks
         of data as they become available, and the argument to the final
         ``callback`` will be empty.  Otherwise, the ``callback`` gets
         the data as an argument.
+
+        If ``partial`` is true, the callback is run as soon as we have
+        any bytes to return (but never more than ``num_bytes``)
         """
         future = self._set_read_callback(callback)
         assert isinstance(num_bytes, numbers.Integral)
         self._read_bytes = num_bytes
+        self._read_partial = partial
         self._streaming_callback = stack_context.wrap(streaming_callback)
         self._try_inline_read()
         return future
@@ -526,9 +532,12 @@ class BaseIOStream(object):
                 self._read_bytes -= bytes_to_consume
             self._run_callback(self._streaming_callback,
                                self._consume(bytes_to_consume))
-        if self._read_bytes is not None and self._read_buffer_size >= self._read_bytes:
-            num_bytes = self._read_bytes
+        if (self._read_bytes is not None and
+            (self._read_buffer_size >= self._read_bytes or
+             (self._read_partial and self._read_buffer_size > 0))):
+            num_bytes = min(self._read_bytes, self._read_buffer_size)
             self._read_bytes = None
+            self._read_partial = False
             self._run_read_callback(self._consume(num_bytes))
             return True
         elif self._read_delimiter is not None:
index 24b37b81c8c81ac3064be7e65f86c5917f66ccf9..f721d09c6a95199cfa3269b0a3210858700331c4 100644 (file)
@@ -533,6 +533,31 @@ class TestIOStreamMixin(object):
             server.close()
             client.close()
 
+    def test_read_bytes_partial(self):
+        server, client = self.make_iostream_pair()
+        try:
+            # Ask for more than is available with partial=True
+            client.read_bytes(50, self.stop, partial=True)
+            server.write(b"hello")
+            data = self.wait()
+            self.assertEqual(data, b"hello")
+
+            # Ask for less than what is available; num_bytes is still
+            # respected.
+            client.read_bytes(3, self.stop, partial=True)
+            server.write(b"world")
+            data = self.wait()
+            self.assertEqual(data, b"wor")
+
+            # Partial reads won't return an empty string, but read_bytes(0)
+            # will.
+            client.read_bytes(0, self.stop, partial=True)
+            data = self.wait()
+            self.assertEqual(data, b'')
+        finally:
+            server.close()
+            client.close()
+
 
 class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
     def _make_client_iostream(self):