git.ipfire.org Git - thirdparty/babel.git/commitdiff
Improve .po IO (#1068)
author: Aarni Koskela <akx@iki.fi>
Wed, 27 Mar 2024 15:46:53 +0000 (17:46 +0200)
committer: GitHub <noreply@github.com>
Wed, 27 Mar 2024 15:46:53 +0000 (17:46 +0200)
* read_po: note that the interface also supports an iterable of strings, not only a file-like object

* write_po: refactor into generate_po

babel/messages/pofile.py
tests/messages/test_pofile.py

index b64a5085e81104498a7894c7624e90f596b420e4..6e9dbdf0314601d00c0ce85b3875500ee29ab927 100644 (file)
@@ -291,7 +291,7 @@ class PoFileParser:
             # These are called user comments
             self.user_comments.append(line[1:].strip())
 
-    def parse(self, fileobj: IO[AnyStr]) -> None:
+    def parse(self, fileobj: IO[AnyStr] | Iterable[AnyStr]) -> None:
         """
         Reads from the file-like object `fileobj` and adds any po file
         units found in it to the `Catalog` supplied to the constructor.
@@ -329,7 +329,7 @@ class PoFileParser:
 
 
 def read_po(
-    fileobj: IO[AnyStr],
+    fileobj: IO[AnyStr] | Iterable[AnyStr],
     locale: str | Locale | None = None,
     domain: str | None = None,
     ignore_obsolete: bool = False,
@@ -337,7 +337,7 @@ def read_po(
     abort_invalid: bool = False,
 ) -> Catalog:
     """Read messages from a ``gettext`` PO (portable object) file from the given
-    file-like object and return a `Catalog`.
+    file-like object (or an iterable of lines) and return a `Catalog`.
 
     >>> from datetime import datetime
     >>> from io import StringIO
@@ -373,7 +373,7 @@ def read_po(
     .. versionadded:: 1.0
        Added support for explicit charset argument.
 
-    :param fileobj: the file-like object to read the PO file from
+    :param fileobj: the file-like object (or iterable of lines) to read the PO file from
     :param locale: the locale identifier or `Locale` object, or `None`
                    if the catalog is not bound to a locale (which basically
                    means it's a template)
@@ -529,45 +529,69 @@ def write_po(
                              updating the catalog
     :param include_lineno: include line number in the location comment
     """
-    def _normalize(key, prefix=''):
-        return normalize(key, prefix=prefix, width=width)
-
-    def _write(text):
-        if isinstance(text, str):
-            text = text.encode(catalog.charset, 'backslashreplace')
-        fileobj.write(text)
-
-    def _write_comment(comment, prefix=''):
-        # xgettext always wraps comments even if --no-wrap is passed;
-        # provide the same behaviour
-        _width = width if width and width > 0 else 76
-        for line in wraptext(comment, _width):
-            _write(f"#{prefix} {line.strip()}\n")
-
-    def _write_message(message, prefix=''):
+
+    sort_by = None
+    if sort_output:
+        sort_by = "message"
+    elif sort_by_file:
+        sort_by = "location"
+
+    for line in generate_po(
+        catalog,
+        ignore_obsolete=ignore_obsolete,
+        include_lineno=include_lineno,
+        include_previous=include_previous,
+        no_location=no_location,
+        omit_header=omit_header,
+        sort_by=sort_by,
+        width=width,
+    ):
+        if isinstance(line, str):
+            line = line.encode(catalog.charset, 'backslashreplace')
+        fileobj.write(line)
+
+
+def generate_po(
+    catalog: Catalog,
+    *,
+    ignore_obsolete: bool = False,
+    include_lineno: bool = True,
+    include_previous: bool = False,
+    no_location: bool = False,
+    omit_header: bool = False,
+    sort_by: Literal["message", "location"] | None = None,
+    width: int = 76,
+) -> Iterable[str]:
+    r"""Yield text strings representing a ``gettext`` PO (portable object) file.
+
+    See `write_po()` for a more detailed description.
+    """
+    # xgettext always wraps comments even if --no-wrap is passed;
+    # provide the same behaviour
+    comment_width = width if width and width > 0 else 76
+
+    def _format_comment(comment, prefix=''):
+        for line in wraptext(comment, comment_width):
+            yield f"#{prefix} {line.strip()}\n"
+
+    def _format_message(message, prefix=''):
         if isinstance(message.id, (list, tuple)):
             if message.context:
-                _write(f"{prefix}msgctxt {_normalize(message.context, prefix)}\n")
-            _write(f"{prefix}msgid {_normalize(message.id[0], prefix)}\n")
-            _write(f"{prefix}msgid_plural {_normalize(message.id[1], prefix)}\n")
+                yield f"{prefix}msgctxt {normalize(message.context, prefix=prefix, width=width)}\n"
+            yield f"{prefix}msgid {normalize(message.id[0], prefix=prefix, width=width)}\n"
+            yield f"{prefix}msgid_plural {normalize(message.id[1], prefix=prefix, width=width)}\n"
 
             for idx in range(catalog.num_plurals):
                 try:
                     string = message.string[idx]
                 except IndexError:
                     string = ''
-                _write(f"{prefix}msgstr[{idx:d}] {_normalize(string, prefix)}\n")
+                yield f"{prefix}msgstr[{idx:d}] {normalize(string, prefix=prefix, width=width)}\n"
         else:
             if message.context:
-                _write(f"{prefix}msgctxt {_normalize(message.context, prefix)}\n")
-            _write(f"{prefix}msgid {_normalize(message.id, prefix)}\n")
-            _write(f"{prefix}msgstr {_normalize(message.string or '', prefix)}\n")
-
-    sort_by = None
-    if sort_output:
-        sort_by = "message"
-    elif sort_by_file:
-        sort_by = "location"
+                yield f"{prefix}msgctxt {normalize(message.context, prefix=prefix, width=width)}\n"
+            yield f"{prefix}msgid {normalize(message.id, prefix=prefix, width=width)}\n"
+            yield f"{prefix}msgstr {normalize(message.string or '', prefix=prefix, width=width)}\n"
 
     for message in _sort_messages(catalog, sort_by=sort_by):
         if not message.id:  # This is the header "message"
@@ -580,12 +604,12 @@ def write_po(
                     lines += wraptext(line, width=width,
                                       subsequent_indent='# ')
                 comment_header = '\n'.join(lines)
-            _write(f"{comment_header}\n")
+            yield f"{comment_header}\n"
 
         for comment in message.user_comments:
-            _write_comment(comment)
+            yield from _format_comment(comment)
         for comment in message.auto_comments:
-            _write_comment(comment, prefix='.')
+            yield from _format_comment(comment, prefix='.')
 
         if not no_location:
             locs = []
@@ -606,22 +630,21 @@ def write_po(
                     location = f"{location}:{lineno:d}"
                 if location not in locs:
                     locs.append(location)
-            _write_comment(' '.join(locs), prefix=':')
+            yield from _format_comment(' '.join(locs), prefix=':')
         if message.flags:
-            _write(f"#{', '.join(['', *sorted(message.flags)])}\n")
+            yield f"#{', '.join(['', *sorted(message.flags)])}\n"
 
         if message.previous_id and include_previous:
-            _write_comment(
-                f'msgid {_normalize(message.previous_id[0])}',
+            yield from _format_comment(
+                f'msgid {normalize(message.previous_id[0], width=width)}',
                 prefix='|',
             )
             if len(message.previous_id) > 1:
-                _write_comment('msgid_plural %s' % _normalize(
-                    message.previous_id[1],
-                ), prefix='|')
+                norm_previous_id = normalize(message.previous_id[1], width=width)
+                yield from _format_comment(f'msgid_plural {norm_previous_id}', prefix='|')
 
-        _write_message(message)
-        _write('\n')
+        yield from _format_message(message)
+        yield '\n'
 
     if not ignore_obsolete:
         for message in _sort_messages(
@@ -629,12 +652,12 @@ def write_po(
             sort_by=sort_by,
         ):
             for comment in message.user_comments:
-                _write_comment(comment)
-            _write_message(message, prefix='#~ ')
-            _write('\n')
+                yield from _format_comment(comment)
+            yield from _format_message(message, prefix='#~ ')
+            yield '\n'
 
 
-def _sort_messages(messages: Iterable[Message], sort_by: Literal["message", "location"]) -> list[Message]:
+def _sort_messages(messages: Iterable[Message], sort_by: Literal["message", "location"] | None) -> list[Message]:
     """
     Sort the given message iterable by the given criteria.
 
index 043f9c8bd15af6d7561e5a5fde5dfadac23a0531..d322857d57e923d8984cc445aa49254491e32b08 100644 (file)
@@ -884,3 +884,12 @@ def test_unknown_language_write():
     buf = BytesIO()
     pofile.write_po(buf, catalog)
     assert 'sr_SP' in buf.getvalue().decode()
+
+
+def test_iterable_of_strings():
+    """
+    Test we can parse from an iterable of strings.
+    """
+    catalog = pofile.read_po(['msgid "foo"', b'msgstr "Voh"'], locale="en_US")
+    assert catalog.locale == Locale("en", "US")
+    assert catalog.get("foo").string == "Voh"