]> git.ipfire.org Git - thirdparty/babel.git/commitdiff
Optimizations for read_po (#1200) master
authorAarni Koskela <akx@iki.fi>
Sat, 5 Apr 2025 15:21:04 +0000 (18:21 +0300)
committerGitHub <noreply@github.com>
Sat, 5 Apr 2025 15:21:04 +0000 (18:21 +0300)
* Avoid extra casts (`Message()` takes care of those)

* Optimize empty normalized strings

* Don't sort translations unless plural

* Optimize unescape()

* Optimize line processing

* Optimize keyword parsing

* Optimize comment parsing

* Avoid hot `isinstance`ing in PO file parse loop

* Add fast paths in `python_format` and `python_brace_format`

* Inline distincting in `catalog.py`

babel/messages/catalog.py
babel/messages/pofile.py
tests/messages/test_normalized_string.py [deleted file]
tests/messages/test_pofile.py

index de96ea576a7a560db65cd069b4aa01f8863912c7..e1d61e0999c30a41f25f4c2d0a6bd1ef888bf5fe 100644 (file)
@@ -23,7 +23,7 @@ from babel import __version__ as VERSION
 from babel.core import Locale, UnknownLocaleError
 from babel.dates import format_datetime
 from babel.messages.plurals import get_plural
-from babel.util import LOCALTZ, _cmp, distinct
+from babel.util import LOCALTZ, _cmp
 
 if TYPE_CHECKING:
     from typing_extensions import TypeAlias
@@ -164,7 +164,7 @@ class Message:
         if not string and self.pluralizable:
             string = ('', '')
         self.string = string
-        self.locations = list(distinct(locations))
+        self.locations = list(dict.fromkeys(locations)) if locations else []
         self.flags = set(flags)
         if id and self.python_format:
             self.flags.add('python-format')
@@ -174,12 +174,15 @@ class Message:
             self.flags.add('python-brace-format')
         else:
             self.flags.discard('python-brace-format')
-        self.auto_comments = list(distinct(auto_comments))
-        self.user_comments = list(distinct(user_comments))
-        if isinstance(previous_id, str):
-            self.previous_id = [previous_id]
+        self.auto_comments = list(dict.fromkeys(auto_comments)) if auto_comments else []
+        self.user_comments = list(dict.fromkeys(user_comments)) if user_comments else []
+        if previous_id:
+            if isinstance(previous_id, str):
+                self.previous_id = [previous_id]
+            else:
+                self.previous_id = list(previous_id)
         else:
-            self.previous_id = list(previous_id)
+            self.previous_id = []
         self.lineno = lineno
         self.context = context
 
@@ -289,9 +292,12 @@ class Message:
 
         :type:  `bool`"""
         ids = self.id
-        if not isinstance(ids, (list, tuple)):
-            ids = [ids]
-        return any(PYTHON_FORMAT.search(id) for id in ids)
+        if isinstance(ids, (list, tuple)):
+            for id in ids:  # Explicit loop for performance reasons.
+                if PYTHON_FORMAT.search(id):
+                    return True
+            return False
+        return bool(PYTHON_FORMAT.search(ids))
 
     @property
     def python_brace_format(self) -> bool:
@@ -304,9 +310,12 @@ class Message:
 
         :type:  `bool`"""
         ids = self.id
-        if not isinstance(ids, (list, tuple)):
-            ids = [ids]
-        return any(_has_python_brace_format(id) for id in ids)
+        if isinstance(ids, (list, tuple)):
+            for id in ids:  # Explicit loop for performance reasons.
+                if _has_python_brace_format(id):
+                    return True
+            return False
+        return _has_python_brace_format(ids)
 
 
 class TranslationError(Exception):
@@ -729,12 +738,9 @@ class Catalog:
                 # The new message adds pluralization
                 current.id = message.id
                 current.string = message.string
-            current.locations = list(distinct(current.locations +
-                                              message.locations))
-            current.auto_comments = list(distinct(current.auto_comments +
-                                                  message.auto_comments))
-            current.user_comments = list(distinct(current.user_comments +
-                                                  message.user_comments))
+            current.locations = list(dict.fromkeys([*current.locations, *message.locations]))
+            current.auto_comments = list(dict.fromkeys([*current.auto_comments, *message.auto_comments]))
+            current.user_comments = list(dict.fromkeys([*current.user_comments, *message.user_comments]))
             current.flags |= message.flags
         elif id == '':
             # special treatment for the header message
@@ -916,8 +922,8 @@ class Catalog:
                 assert oldmsg is not None
             message.string = oldmsg.string
 
-            if keep_user_comments:
-                message.user_comments = list(distinct(oldmsg.user_comments))
+            if keep_user_comments and oldmsg.user_comments:
+                message.user_comments = list(dict.fromkeys(oldmsg.user_comments))
 
             if isinstance(message.id, (list, tuple)):
                 if not isinstance(message.string, (list, tuple)):
index 3afdd60610c1aebd353b0a7dbc99fde2f884bee0..8220637bcea7a32a073c692a04481f3b571a9ae0 100644 (file)
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Literal
 
 from babel.core import Locale
 from babel.messages.catalog import Catalog, Message
-from babel.util import TextWrapper, _cmp
+from babel.util import TextWrapper
 
 if TYPE_CHECKING:
     from typing import IO, AnyStr
@@ -25,6 +25,9 @@ if TYPE_CHECKING:
     from _typeshed import SupportsWrite
 
 
+_unescape_re = re.compile(r'\\([\\trn"])')
+
+
 def unescape(string: str) -> str:
     r"""Reverse `escape` the given string.
 
@@ -45,7 +48,10 @@ def unescape(string: str) -> str:
             return '\r'
         # m is \ or "
         return m
-    return re.compile(r'\\([\\trn"])').sub(replace_escapes, string[1:-1])
+
+    if "\\" not in string:  # Fast path: there's nothing to unescape
+        return string[1:-1]
+    return _unescape_re.sub(replace_escapes, string[1:-1])
 
 
 def denormalize(string: str) -> str:
@@ -73,7 +79,7 @@ def denormalize(string: str) -> str:
         escaped_lines = string.splitlines()
         if string.startswith('""'):
             escaped_lines = escaped_lines[1:]
-        return ''.join(unescape(line) for line in escaped_lines)
+        return ''.join(map(unescape, escaped_lines))
     else:
         return unescape(string)
 
@@ -132,48 +138,14 @@ class PoFileError(Exception):
         self.lineno = lineno
 
 
-class _NormalizedString:
-
+class _NormalizedString(list):
     def __init__(self, *args: str) -> None:
-        self._strs: list[str] = []
-        for arg in args:
-            self.append(arg)
-
-    def append(self, s: str) -> None:
-        self._strs.append(s.strip())
+        super().__init__(map(str.strip, args))
 
     def denormalize(self) -> str:
-        return ''.join(unescape(s) for s in self._strs)
-
-    def __bool__(self) -> bool:
-        return bool(self._strs)
-
-    def __repr__(self) -> str:
-        return os.linesep.join(self._strs)
-
-    def __cmp__(self, other: object) -> int:
-        if not other:
-            return 1
-
-        return _cmp(str(self), str(other))
-
-    def __gt__(self, other: object) -> bool:
-        return self.__cmp__(other) > 0
-
-    def __lt__(self, other: object) -> bool:
-        return self.__cmp__(other) < 0
-
-    def __ge__(self, other: object) -> bool:
-        return self.__cmp__(other) >= 0
-
-    def __le__(self, other: object) -> bool:
-        return self.__cmp__(other) <= 0
-
-    def __eq__(self, other: object) -> bool:
-        return self.__cmp__(other) == 0
-
-    def __ne__(self, other: object) -> bool:
-        return self.__cmp__(other) != 0
+        if not self:
+            return ""
+        return ''.join(map(unescape, self))
 
 
 class PoFileParser:
@@ -183,13 +155,6 @@ class PoFileParser:
     See `read_po` for simple cases.
     """
 
-    _keywords = [
-        'msgid',
-        'msgstr',
-        'msgctxt',
-        'msgid_plural',
-    ]
-
     def __init__(self, catalog: Catalog, ignore_obsolete: bool = False, abort_invalid: bool = False) -> None:
         self.catalog = catalog
         self.ignore_obsolete = ignore_obsolete
@@ -216,23 +181,20 @@ class PoFileParser:
         Add a message to the catalog based on the current parser state and
         clear the state ready to process the next message.
         """
-        self.translations.sort()
         if len(self.messages) > 1:
             msgid = tuple(m.denormalize() for m in self.messages)
-        else:
-            msgid = self.messages[0].denormalize()
-        if isinstance(msgid, (list, tuple)):
             string = ['' for _ in range(self.catalog.num_plurals)]
-            for idx, translation in self.translations:
+            for idx, translation in sorted(self.translations):
                 if idx >= self.catalog.num_plurals:
                     self._invalid_pofile("", self.offset, "msg has more translations than num_plurals of catalog")
                     continue
                 string[idx] = translation.denormalize()
             string = tuple(string)
         else:
+            msgid = self.messages[0].denormalize()
             string = self.translations[0][1].denormalize()
         msgctxt = self.context.denormalize() if self.context else None
-        message = Message(msgid, string, list(self.locations), set(self.flags),
+        message = Message(msgid, string, self.locations, self.flags,
                           self.auto_comments, self.user_comments, lineno=self.offset + 1,
                           context=msgctxt)
         if self.obsolete:
@@ -247,27 +209,19 @@ class PoFileParser:
         if self.messages:
             if not self.translations:
                 self._invalid_pofile("", self.offset, f"missing msgstr for msgid '{self.messages[0].denormalize()}'")
-                self.translations.append([0, _NormalizedString("")])
+                self.translations.append([0, _NormalizedString()])
             self._add_message()
 
     def _process_message_line(self, lineno, line, obsolete=False) -> None:
-        if line.startswith('"'):
+        if not line:
+            return
+        if line[0] == '"':
             self._process_string_continuation_line(line, lineno)
         else:
             self._process_keyword_line(lineno, line, obsolete)
 
     def _process_keyword_line(self, lineno, line, obsolete=False) -> None:
-
-        for keyword in self._keywords:
-            try:
-                if line.startswith(keyword) and line[len(keyword)] in [' ', '[']:
-                    arg = line[len(keyword):]
-                    break
-            except IndexError:
-                self._invalid_pofile(line, lineno, "Keyword must be followed by a string")
-        else:
-            self._invalid_pofile(line, lineno, "Start of line didn't match any expected keyword.")
-            return
+        keyword, _, arg = line.partition(' ')
 
         if keyword in ['msgid', 'msgctxt']:
             self._finish_current_message()
@@ -283,19 +237,23 @@ class PoFileParser:
             self.in_msgctxt = False
             self.in_msgid = True
             self.messages.append(_NormalizedString(arg))
+            return
 
-        elif keyword == 'msgstr':
+        if keyword == 'msgctxt':
+            self.in_msgctxt = True
+            self.context = _NormalizedString(arg)
+            return
+
+        if keyword == 'msgstr' or keyword.startswith('msgstr['):
             self.in_msgid = False
             self.in_msgstr = True
-            if arg.startswith('['):
-                idx, msg = arg[1:].split(']', 1)
-                self.translations.append([int(idx), _NormalizedString(msg)])
-            else:
-                self.translations.append([0, _NormalizedString(arg)])
+            kwarg, has_bracket, idxarg = keyword.partition('[')
+            idx = int(idxarg[:-1]) if has_bracket else 0
+            s = _NormalizedString(arg) if arg != '""' else _NormalizedString()
+            self.translations.append([idx, s])
+            return
 
-        elif keyword == 'msgctxt':
-            self.in_msgctxt = True
-            self.context = _NormalizedString(arg)
+        self._invalid_pofile(line, lineno, "Unknown or misformatted keyword")
 
     def _process_string_continuation_line(self, line, lineno) -> None:
         if self.in_msgid:
@@ -307,49 +265,62 @@ class PoFileParser:
         else:
             self._invalid_pofile(line, lineno, "Got line starting with \" but not in msgid, msgstr or msgctxt")
             return
-        s.append(line)
+        s.append(line.strip())  # For performance reasons, `NormalizedString` doesn't strip internally
 
     def _process_comment(self, line) -> None:
 
         self._finish_current_message()
 
-        if line[1:].startswith(':'):
+        prefix = line[:2]
+        if prefix == '#:':
             for location in _extract_locations(line[2:]):
-                pos = location.rfind(':')
-                if pos >= 0:
+                a, colon, b = location.rpartition(':')
+                if colon:
                     try:
-                        lineno = int(location[pos + 1:])
+                        self.locations.append((a, int(b)))
                     except ValueError:
                         continue
-                    self.locations.append((location[:pos], lineno))
-                else:
+                else:  # No line number specified
                     self.locations.append((location, None))
-        elif line[1:].startswith(','):
-            for flag in line[2:].lstrip().split(','):
-                self.flags.append(flag.strip())
-        elif line[1:].startswith('.'):
+            return
+
+        if prefix == '#,':
+            self.flags.extend(flag.strip() for flag in line[2:].lstrip().split(','))
+            return
+
+        if prefix == '#.':
             # These are called auto-comments
             comment = line[2:].strip()
             if comment:  # Just check that we're not adding empty comments
                 self.auto_comments.append(comment)
-        else:
-            # These are called user comments
-            self.user_comments.append(line[1:].strip())
+            return
+
+        # These are called user comments
+        self.user_comments.append(line[1:].strip())
 
     def parse(self, fileobj: IO[AnyStr] | Iterable[AnyStr]) -> None:
         """
-        Reads from the file-like object `fileobj` and adds any po file
-        units found in it to the `Catalog` supplied to the constructor.
+        Reads from the file-like object (or iterable of string-likes) `fileobj`
+        and adds any po file units found in it to the `Catalog`
+        supplied to the constructor.
+
+        All of the items in the iterable must be the same type; either `str`
+        or `bytes` (decoded with the catalog charset), but not a mixture.
         """
+        needs_decode = None
 
         for lineno, line in enumerate(fileobj):
             line = line.strip()
-            if not isinstance(line, str):
-                line = line.decode(self.catalog.charset)
+            if needs_decode is None:
+                # If we don't yet know whether we need to decode,
+                # let's find out now.
+                needs_decode = not isinstance(line, str)
             if not line:
                 continue
-            if line.startswith('#'):
-                if line[1:].startswith('~'):
+            if needs_decode:
+                line = line.decode(self.catalog.charset)
+            if line[0] == '#':
+                if line[:2] == '#~':
                     self._process_message_line(lineno, line[2:].lstrip(), obsolete=True)
                 else:
                     try:
@@ -364,8 +335,8 @@ class PoFileParser:
         # No actual messages found, but there was some info in comments, from which
         # we'll construct an empty header message
         if not self.counter and (self.flags or self.user_comments or self.auto_comments):
-            self.messages.append(_NormalizedString('""'))
-            self.translations.append([0, _NormalizedString('""')])
+            self.messages.append(_NormalizedString())
+            self.translations.append([0, _NormalizedString()])
             self._add_message()
 
     def _invalid_pofile(self, line, lineno, msg) -> None:
diff --git a/tests/messages/test_normalized_string.py b/tests/messages/test_normalized_string.py
deleted file mode 100644 (file)
index 9c95672..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-from babel.messages.pofile import _NormalizedString
-
-
-def test_normalized_string():
-    ab1 = _NormalizedString('a', 'b ')
-    ab2 = _NormalizedString('a', ' b')
-    ac1 = _NormalizedString('a', 'c')
-    ac2 = _NormalizedString('  a', 'c  ')
-    z = _NormalizedString()
-    assert ab1 == ab2 and ac1 == ac2  # __eq__
-    assert ab1 < ac1  # __lt__
-    assert ac1 > ab2  # __gt__
-    assert ac1 >= ac2  # __ge__
-    assert ab1 <= ab2  # __le__
-    assert ab1 != ac1  # __ne__
-    assert not z  # __nonzero__ / __bool__
-    assert sorted([ab1, ab2, ac1])  # the sort order is not stable so we can't really check it, just that we can sort
index 2bcc3df8d9cadacbf16a19941114b26ecaae195d..ffc95295c493d572d14071c23621245f9639cc83 100644 (file)
@@ -1068,11 +1068,18 @@ def test_iterable_of_strings():
     """
     Test we can parse from an iterable of strings.
     """
-    catalog = pofile.read_po(['msgid "foo"', b'msgstr "Voh"'], locale="en_US")
+    catalog = pofile.read_po(['msgid "foo"', 'msgstr "Voh"'], locale="en_US")
     assert catalog.locale == Locale("en", "US")
     assert catalog.get("foo").string == "Voh"
 
 
+@pytest.mark.parametrize("order", [1, -1])
+def test_iterable_of_mismatching_strings(order):
+    # Mixing and matching byteses and strs in the same read_po call is not allowed.
+    with pytest.raises(Exception):  # noqa: B017 (will raise either TypeError or AttributeError)
+        pofile.read_po(['msgid "foo"', b'msgstr "Voh"'][::order])
+
+
 def test_issue_1087():
     buf = StringIO(r'''
 msgid ""