]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-81322: support multiple separators in StreamReader.readuntil (#16429)
authorBruce Merry <1963944+bmerry@users.noreply.github.com>
Mon, 8 Apr 2024 16:58:02 +0000 (18:58 +0200)
committerGitHub <noreply@github.com>
Mon, 8 Apr 2024 16:58:02 +0000 (09:58 -0700)
Doc/library/asyncio-stream.rst
Lib/asyncio/streams.py
Lib/test/test_asyncio/test_streams.py
Misc/NEWS.d/next/Library/2019-09-26-17-52-52.bpo-37141.onYY2-.rst [new file with mode: 0644]

index 68b1dff20213e158d1f180cec65aae78d04a6dbe..6231b49b1e2431b0bc119b51e1004ca18aabc657 100644 (file)
@@ -260,8 +260,19 @@ StreamReader
       buffer is reset.  The :attr:`IncompleteReadError.partial` attribute
       may contain a portion of the separator.
 
+      The *separator* may also be an :term:`iterable` of separators. In this
+      case the return value will be the shortest possible that has any
+      separator as the suffix. For the purposes of :exc:`LimitOverrunError`,
+      the shortest possible separator is considered to be the one that
+      matched.
+
       .. versionadded:: 3.5.2
 
+      .. versionchanged:: 3.13
+
+         The *separator* parameter may now be an :term:`iterable` of
+         separators.
+
    .. method:: at_eof()
 
       Return ``True`` if the buffer is empty and :meth:`feed_eof`
index 3fe52dbac25c916eec8855c19060da80e4702a7f..4517ca22d74637943194d4cf98d30dc9487ac0a9 100644 (file)
@@ -590,20 +590,34 @@ class StreamReader:
         If the data cannot be read because of over limit, a
         LimitOverrunError exception  will be raised, and the data
         will be left in the internal buffer, so it can be read again.
+
+        The ``separator`` may also be an iterable of separators. In this
+        case the return value will be the shortest possible that has any
+        separator as the suffix. For the purposes of LimitOverrunError,
+        the shortest possible separator is considered to be the one that
+        matched.
         """
-        seplen = len(separator)
-        if seplen == 0:
+        if isinstance(separator, bytes):
+            separator = [separator]
+        else:
+            # Makes sure shortest matches wins, and supports arbitrary iterables
+            separator = sorted(separator, key=len)
+        if not separator:
+            raise ValueError('Separator should contain at least one element')
+        min_seplen = len(separator[0])
+        max_seplen = len(separator[-1])
+        if min_seplen == 0:
             raise ValueError('Separator should be at least one-byte string')
 
         if self._exception is not None:
             raise self._exception
 
         # Consume whole buffer except last bytes, which length is
-        # one less than seplen. Let's check corner cases with
-        # separator='SEPARATOR':
+        # one less than max_seplen. Let's check corner cases with
+        # separator[-1]='SEPARATOR':
         # * we have received almost complete separator (without last
         #   byte). i.e buffer='some textSEPARATO'. In this case we
-        #   can safely consume len(separator) - 1 bytes.
+        #   can safely consume max_seplen - 1 bytes.
         # * last byte of buffer is first byte of separator, i.e.
         #   buffer='abcdefghijklmnopqrS'. We may safely consume
         #   everything except that last byte, but this require to
@@ -616,26 +630,35 @@ class StreamReader:
         #   messages :)
 
         # `offset` is the number of bytes from the beginning of the buffer
-        # where there is no occurrence of `separator`.
+        # where there is no occurrence of any `separator`.
         offset = 0
 
-        # Loop until we find `separator` in the buffer, exceed the buffer size,
+        # Loop until we find `separator` in the buffer, exceed the buffer size,
         # or an EOF has happened.
         while True:
             buflen = len(self._buffer)
 
-            # Check if we now have enough data in the buffer for `separator` to
-            # fit.
-            if buflen - offset >= seplen:
-                isep = self._buffer.find(separator, offset)
-
-                if isep != -1:
-                    # `separator` is in the buffer. `isep` will be used later
-                    # to retrieve the data.
+            # Check if we now have enough data in the buffer for shortest
+            # separator to fit.
+            if buflen - offset >= min_seplen:
+                match_start = None
+                match_end = None
+                for sep in separator:
+                    isep = self._buffer.find(sep, offset)
+
+                    if isep != -1:
+                        # `separator` is in the buffer. `match_start` and
+                        # `match_end` will be used later to retrieve the
+                        # data.
+                        end = isep + len(sep)
+                        if match_end is None or end < match_end:
+                            match_end = end
+                            match_start = isep
+                if match_end is not None:
                     break
 
                 # see upper comment for explanation.
-                offset = buflen + 1 - seplen
+                offset = max(0, buflen + 1 - max_seplen)
                 if offset > self._limit:
                     raise exceptions.LimitOverrunError(
                         'Separator is not found, and chunk exceed the limit',
@@ -644,7 +667,7 @@ class StreamReader:
             # Complete message (with full separator) may be present in buffer
             # even when EOF flag is set. This may happen when the last chunk
             # adds data which makes separator be found. That's why we check for
-            # EOF *ater* inspecting the buffer.
+            # EOF *after* inspecting the buffer.
             if self._eof:
                 chunk = bytes(self._buffer)
                 self._buffer.clear()
@@ -653,12 +676,12 @@ class StreamReader:
             # _wait_for_data() will resume reading if stream was paused.
             await self._wait_for_data('readuntil')
 
-        if isep > self._limit:
+        if match_start > self._limit:
             raise exceptions.LimitOverrunError(
-                'Separator is found, but chunk is longer than limit', isep)
+                'Separator is found, but chunk is longer than limit', match_start)
 
-        chunk = self._buffer[:isep + seplen]
-        del self._buffer[:isep + seplen]
+        chunk = self._buffer[:match_end]
+        del self._buffer[:match_end]
         self._maybe_resume_transport()
         return bytes(chunk)
 
index 2cf48538d5d30d241afd1baa916df4292073e02f..792e88761acdc22c2b39792a3c8c6879772dcec4 100644 (file)
@@ -383,6 +383,10 @@ class StreamTests(test_utils.TestCase):
         stream = asyncio.StreamReader(loop=self.loop)
         with self.assertRaisesRegex(ValueError, 'Separator should be'):
             self.loop.run_until_complete(stream.readuntil(separator=b''))
+        with self.assertRaisesRegex(ValueError, 'Separator should be'):
+            self.loop.run_until_complete(stream.readuntil(separator=[b'']))
+        with self.assertRaisesRegex(ValueError, 'Separator should contain'):
+            self.loop.run_until_complete(stream.readuntil(separator=[]))
 
     def test_readuntil_multi_chunks(self):
         stream = asyncio.StreamReader(loop=self.loop)
@@ -466,6 +470,48 @@ class StreamTests(test_utils.TestCase):
 
         self.assertEqual(b'some dataAAA', stream._buffer)
 
+    def test_readuntil_multi_separator(self):
+        stream = asyncio.StreamReader(loop=self.loop)
+
+        # Simple case
+        stream.feed_data(b'line 1\nline 2\r')
+        data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n']))
+        self.assertEqual(b'line 1\n', data)
+        data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n']))
+        self.assertEqual(b'line 2\r', data)
+        self.assertEqual(b'', stream._buffer)
+
+        # First end position matches, even if that's a longer match
+        stream.feed_data(b'ABCDEFG')
+        data = self.loop.run_until_complete(stream.readuntil([b'DEF', b'BCDE']))
+        self.assertEqual(b'ABCDE', data)
+        self.assertEqual(b'FG', stream._buffer)
+
+    def test_readuntil_multi_separator_limit(self):
+        stream = asyncio.StreamReader(loop=self.loop, limit=3)
+        stream.feed_data(b'some dataA')
+
+        with self.assertRaisesRegex(asyncio.LimitOverrunError,
+                                    'is found') as cm:
+            self.loop.run_until_complete(stream.readuntil([b'A', b'ome dataA']))
+
+        self.assertEqual(b'some dataA', stream._buffer)
+
+    def test_readuntil_multi_separator_negative_offset(self):
+        # If the buffer is big enough for the smallest separator (but does
+        # not contain it) but too small for the largest, `offset` must not
+        # become negative.
+        stream = asyncio.StreamReader(loop=self.loop)
+        stream.feed_data(b'data')
+
+        readuntil_task = self.loop.create_task(stream.readuntil([b'A', b'long sep']))
+        self.loop.call_soon(stream.feed_data, b'Z')
+        self.loop.call_soon(stream.feed_data, b'Aaaa')
+
+        data = self.loop.run_until_complete(readuntil_task)
+        self.assertEqual(b'dataZA', data)
+        self.assertEqual(b'aaa', stream._buffer)
+
     def test_readexactly_zero_or_less(self):
         # Read exact number of bytes (zero or less).
         stream = asyncio.StreamReader(loop=self.loop)
diff --git a/Misc/NEWS.d/next/Library/2019-09-26-17-52-52.bpo-37141.onYY2-.rst b/Misc/NEWS.d/next/Library/2019-09-26-17-52-52.bpo-37141.onYY2-.rst
new file mode 100644 (file)
index 0000000..d916f31
--- /dev/null
@@ -0,0 +1,2 @@
+Accept an iterable of separators in :meth:`asyncio.StreamReader.readuntil`, stopping
+when one of them is encountered.