]> git.ipfire.org Git - thirdparty/babel.git/commitdiff
Avoid hot `isinstance`ing in PO file parse loop
authorAarni Koskela <akx@iki.fi>
Mon, 17 Mar 2025 12:37:31 +0000 (14:37 +0200)
committerAarni Koskela <akx@iki.fi>
Fri, 21 Mar 2025 06:23:57 +0000 (08:23 +0200)
babel/messages/pofile.py
tests/messages/test_pofile.py

index 987193e90ded2663073057e00b6750b436015da7..8220637bcea7a32a073c692a04481f3b571a9ae0 100644 (file)
@@ -300,16 +300,25 @@ class PoFileParser:
 
     def parse(self, fileobj: IO[AnyStr] | Iterable[AnyStr]) -> None:
         """
-        Reads from the file-like object `fileobj` and adds any po file
-        units found in it to the `Catalog` supplied to the constructor.
+        Reads from the file-like object (or iterable of string-likes) `fileobj`
+        and adds any po file units found in it to the `Catalog`
+        supplied to the constructor.
+
+        All of the items in the iterable must be the same type; either `str`
+        or `bytes` (decoded with the catalog charset), but not a mixture.
         """
+        needs_decode = None
 
         for lineno, line in enumerate(fileobj):
             line = line.strip()
-            if not isinstance(line, str):
-                line = line.decode(self.catalog.charset)
+            if needs_decode is None:
+                # If we don't yet know whether we need to decode,
+                # let's find out now.
+                needs_decode = not isinstance(line, str)
             if not line:
                 continue
+            if needs_decode:
+                line = line.decode(self.catalog.charset)
             if line[0] == '#':
                 if line[:2] == '#~':
                     self._process_message_line(lineno, line[2:].lstrip(), obsolete=True)
index 2bcc3df8d9cadacbf16a19941114b26ecaae195d..ffc95295c493d572d14071c23621245f9639cc83 100644 (file)
@@ -1068,11 +1068,18 @@ def test_iterable_of_strings():
     """
     Test we can parse from an iterable of strings.
     """
-    catalog = pofile.read_po(['msgid "foo"', b'msgstr "Voh"'], locale="en_US")
+    catalog = pofile.read_po(['msgid "foo"', 'msgstr "Voh"'], locale="en_US")
     assert catalog.locale == Locale("en", "US")
     assert catalog.get("foo").string == "Voh"
 
 
+@pytest.mark.parametrize("order", [1, -1])
+def test_iterable_of_mismatching_strings(order):
+    # Mixing and matching byteses and strs in the same read_po call is not allowed.
+    with pytest.raises(Exception):  # noqa: B017 (will raise either TypeError or AttributeError)
+        pofile.read_po(['msgid "foo"', b'msgstr "Voh"'][::order])
+
+
 def test_issue_1087():
     buf = StringIO(r'''
 msgid ""