]> git.ipfire.org Git - thirdparty/babel.git/commitdiff
Improve type annotation for `babel.support.Translations.load` (#983)
authorAarni Koskela <akx@iki.fi>
Thu, 2 Mar 2023 14:13:12 +0000 (16:13 +0200)
committerGitHub <noreply@github.com>
Thu, 2 Mar 2023 14:13:12 +0000 (14:13 +0000)
Fixes #982

Co-authored-by: Jonah Lawrence <jonah@freshidea.com>
babel/support.py

index d6ff73aa21173da5a5d0ec3fb1ad4c1a919bccaf..40ce9814d67ad18a48091d53a964c8f16dee18ed 100644 (file)
@@ -17,7 +17,7 @@ import gettext
 import locale
 import os
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Iterable
 
 from babel.core import Locale
 from babel.dates import format_date, format_datetime, format_time, format_timedelta
@@ -615,7 +615,7 @@ class Translations(NullTranslations, gettext.GNUTranslations):
     def load(
         cls,
         dirname: str | os.PathLike[str] | None = None,
-        locales: list[str] | tuple[str, ...] | str | None = None,
+        locales: Iterable[str | Locale] | str | Locale | None = None,
         domain: str | None = None,
     ) -> NullTranslations:
         """Load translations from the given directory.
@@ -626,13 +626,9 @@ class Translations(NullTranslations, gettext.GNUTranslations):
                         strings)
         :param domain: the message domain (default: 'messages')
         """
-        if locales is not None:
-            if not isinstance(locales, (list, tuple)):
-                locales = [locales]
-            locales = [str(locale) for locale in locales]
         if not domain:
             domain = cls.DEFAULT_DOMAIN
-        filename = gettext.find(domain, dirname, locales)
+        filename = gettext.find(domain, dirname, _locales_to_names(locales))
         if not filename:
             return NullTranslations()
         with open(filename, 'rb') as fp:
@@ -683,3 +679,21 @@ class Translations(NullTranslations, gettext.GNUTranslations):
                 self.files.extend(translations.files)
 
         return self
+
+
+def _locales_to_names(
+    locales: Iterable[str | Locale] | str | Locale | None,
+) -> list[str] | None:
+    """Normalize a `locales` argument to a list of locale names.
+
+    :param locales: the list of locales in order of preference (items in
+                    this list can be either `Locale` objects or locale
+                    strings)
+    """
+    if locales is None:
+        return None
+    if isinstance(locales, Locale):
+        return [str(locale)]
+    if isinstance(locales, str):
+        return [locales]
+    return [str(locale) for locale in locales]