]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add max_bytes parameter to read_until and read_until_regex.
author: Ben Darnell <ben@bendarnell.com>
Sun, 6 Apr 2014 11:43:33 +0000 (12:43 +0100)
committer: Ben Darnell <ben@bendarnell.com>
Sun, 6 Apr 2014 12:03:31 +0000 (13:03 +0100)
tornado/iostream.py
tornado/test/iostream_test.py

index 98afea9c29aaac729cb24fd5c47999d7d4e0b427..90d2a29457a2d5041e8c9776cbabd693c5d56e16 100644 (file)
@@ -67,6 +67,15 @@ class StreamClosedError(IOError):
     pass
 
 
+class UnsatisfiableReadError(Exception):
+    """Exception raised when a read cannot be satisfied.
+
+    Raised by ``read_until`` and ``read_until_regex`` with a ``max_bytes``
+    argument.
+    """
+    pass
+
+
 class BaseIOStream(object):
     """A utility class to write to and read from a non-blocking file or socket.
 
@@ -92,6 +101,7 @@ class BaseIOStream(object):
         self._write_buffer_frozen = False
         self._read_delimiter = None
         self._read_regex = None
+        self._read_max_bytes = None
         self._read_bytes = None
         self._read_partial = False
         self._read_until_close = False
@@ -147,26 +157,48 @@ class BaseIOStream(object):
         """
         return None
 
-    def read_until_regex(self, regex, callback=None):
+    def read_until_regex(self, regex, callback=None, max_bytes=None):
         """Run ``callback`` when we read the given regex pattern.
 
         The callback will get the data read (including the data that
         matched the regex and anything that came before it) as an argument.
+
+        If ``max_bytes`` is not None, the connection will be closed
+        if more than ``max_bytes`` bytes have been read and the regex is
+        not satisfied.
         """
         future = self._set_read_callback(callback)
         self._read_regex = re.compile(regex)
-        self._try_inline_read()
+        self._read_max_bytes = max_bytes
+        try:
+            self._try_inline_read()
+        except UnsatisfiableReadError as e:
+            # Handle this the same way as in _handle_events.
+            gen_log.info("Unsatisfiable read, closing connection: %s" % e)
+            self.close(exc_info=True)
+            return future
         return future
 
-    def read_until(self, delimiter, callback=None):
+    def read_until(self, delimiter, callback=None, max_bytes=None):
         """Run ``callback`` when we read the given delimiter.
 
         The callback will get the data read (including the delimiter)
         as an argument.
+
+        If ``max_bytes`` is not None, the connection will be closed
+        if more than ``max_bytes`` bytes have been read and the delimiter
+        is not found.
         """
         future = self._set_read_callback(callback)
         self._read_delimiter = delimiter
-        self._try_inline_read()
+        self._read_max_bytes = max_bytes
+        try:
+            self._try_inline_read()
+        except UnsatisfiableReadError as e:
+            # Handle this the same way as in _handle_events.
+            gen_log.info("Unsatisfiable read, closing connection: %s" % e)
+            self.close(exc_info=True)
+            return future
         return future
 
     def read_bytes(self, num_bytes, callback=None, streaming_callback=None,
@@ -363,6 +395,9 @@ class BaseIOStream(object):
                     "shouldn't happen: _handle_events without self._state"
                 self._state = state
                 self.io_loop.update_handler(self.fileno(), self._state)
+        except UnsatisfiableReadError as e:
+            gen_log.info("Unsatisfiable read, closing connection: %s" % e)
+            self.close(exc_info=True)
         except Exception:
             gen_log.error("Uncaught exception, closing connection.",
                           exc_info=True)
@@ -554,6 +589,8 @@ class BaseIOStream(object):
                     loc = self._read_buffer[0].find(self._read_delimiter)
                     if loc != -1:
                         delimiter_len = len(self._read_delimiter)
+                        self._check_max_bytes(self._read_delimiter,
+                                              loc + delimiter_len)
                         self._read_delimiter = None
                         self._run_read_callback(
                             self._consume(loc + delimiter_len))
@@ -561,19 +598,31 @@ class BaseIOStream(object):
                     if len(self._read_buffer) == 1:
                         break
                     _double_prefix(self._read_buffer)
+                self._check_max_bytes(self._read_delimiter,
+                                      len(self._read_buffer[0]))
         elif self._read_regex is not None:
             if self._read_buffer:
                 while True:
                     m = self._read_regex.search(self._read_buffer[0])
                     if m is not None:
+                        self._check_max_bytes(self._read_regex, m.end())
                         self._read_regex = None
                         self._run_read_callback(self._consume(m.end()))
                         return True
                     if len(self._read_buffer) == 1:
                         break
                     _double_prefix(self._read_buffer)
+                self._check_max_bytes(self._read_regex,
+                                      len(self._read_buffer[0]))
         return False
 
+    def _check_max_bytes(self, delimiter, size):
+        if (self._read_max_bytes is not None and
+            size > self._read_max_bytes):
+            raise UnsatisfiableReadError(
+                "delimiter %r not found within %d bytes" % (
+                    delimiter, self._read_max_bytes))
+
     def _handle_write(self):
         while self._write_buffer:
             try:
index f721d09c6a95199cfa3269b0a3210858700331c4..893c3214d32e6c20d51403bbb785b03dae122e48 100644 (file)
@@ -558,6 +558,125 @@ class TestIOStreamMixin(object):
             server.close()
             client.close()
 
+    def test_read_until_max_bytes(self):
+        server, client = self.make_iostream_pair()
+        client.set_close_callback(lambda: self.stop("closed"))
+        try:
+            # Extra room under the limit
+            client.read_until(b"def", self.stop, max_bytes=50)
+            server.write(b"abcdef")
+            data = self.wait()
+            self.assertEqual(data, b"abcdef")
+
+            # Just enough space
+            client.read_until(b"def", self.stop, max_bytes=6)
+            server.write(b"abcdef")
+            data = self.wait()
+            self.assertEqual(data, b"abcdef")
+
+            # Not enough space; we won't know that until the data arrives,
+            # so all we can do is log a warning and close the connection.
+            with ExpectLog(gen_log, "Unsatisfiable read"):
+                client.read_until(b"def", self.stop, max_bytes=5)
+                server.write(b"123456")
+                data = self.wait()
+            self.assertEqual(data, "closed")
+        finally:
+            server.close()
+            client.close()
+
+    def test_read_until_max_bytes_inline(self):
+        server, client = self.make_iostream_pair()
+        client.set_close_callback(lambda: self.stop("closed"))
+        try:
+            # Similar to the error case in the previous test, but the
+            # server writes first so client reads are satisfied
+            # inline.  For consistency with the out-of-line case, we
+            # do not raise the error synchronously.
+            server.write(b"123456")
+            with ExpectLog(gen_log, "Unsatisfiable read"):
+                client.read_until(b"def", self.stop, max_bytes=5)
+                data = self.wait()
+            self.assertEqual(data, "closed")
+        finally:
+            server.close()
+            client.close()
+
+    def test_read_until_max_bytes_ignores_extra(self):
+        server, client = self.make_iostream_pair()
+        client.set_close_callback(lambda: self.stop("closed"))
+        try:
+            # Even though data that matches arrives in the same packet that
+            # puts us over the limit, we fail the request because it was
+            # not found within the limit.
+            server.write(b"abcdef")
+            with ExpectLog(gen_log, "Unsatisfiable read"):
+                client.read_until(b"def", self.stop, max_bytes=5)
+                data = self.wait()
+            self.assertEqual(data, "closed")
+        finally:
+            server.close()
+            client.close()
+
+    def test_read_until_regex_max_bytes(self):
+        server, client = self.make_iostream_pair()
+        client.set_close_callback(lambda: self.stop("closed"))
+        try:
+            # Extra room under the limit
+            client.read_until_regex(b"def", self.stop, max_bytes=50)
+            server.write(b"abcdef")
+            data = self.wait()
+            self.assertEqual(data, b"abcdef")
+
+            # Just enough space
+            client.read_until_regex(b"def", self.stop, max_bytes=6)
+            server.write(b"abcdef")
+            data = self.wait()
+            self.assertEqual(data, b"abcdef")
+
+            # Not enough space; we won't know that until the data arrives,
+            # so all we can do is log a warning and close the connection.
+            with ExpectLog(gen_log, "Unsatisfiable read"):
+                client.read_until_regex(b"def", self.stop, max_bytes=5)
+                server.write(b"123456")
+                data = self.wait()
+            self.assertEqual(data, "closed")
+        finally:
+            server.close()
+            client.close()
+
+    def test_read_until_regex_max_bytes_inline(self):
+        server, client = self.make_iostream_pair()
+        client.set_close_callback(lambda: self.stop("closed"))
+        try:
+            # Similar to the error case in the previous test, but the
+            # server writes first so client reads are satisfied
+            # inline.  For consistency with the out-of-line case, we
+            # do not raise the error synchronously.
+            server.write(b"123456")
+            with ExpectLog(gen_log, "Unsatisfiable read"):
+                client.read_until_regex(b"def", self.stop, max_bytes=5)
+                data = self.wait()
+            self.assertEqual(data, "closed")
+        finally:
+            server.close()
+            client.close()
+
+    def test_read_until_regex_max_bytes_ignores_extra(self):
+        server, client = self.make_iostream_pair()
+        client.set_close_callback(lambda: self.stop("closed"))
+        try:
+            # Even though data that matches arrives in the same packet that
+            # puts us over the limit, we fail the request because it was
+            # not found within the limit.
+            server.write(b"abcdef")
+            with ExpectLog(gen_log, "Unsatisfiable read"):
+                client.read_until_regex(b"def", self.stop, max_bytes=5)
+                data = self.wait()
+            self.assertEqual(data, "closed")
+        finally:
+            server.close()
+            client.close()
 
 class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
     def _make_client_iostream(self):