]> git.ipfire.org Git - thirdparty/babel.git/commitdiff
Numbers and core type fixes
authorJonah Lawrence <jonah@freshidea.com>
Fri, 3 Feb 2023 15:35:09 +0000 (08:35 -0700)
committerJonah Lawrence <jonah@freshidea.com>
Fri, 3 Feb 2023 15:35:09 +0000 (08:35 -0700)
babel/core.py
babel/numbers.py

index bdd176acad9f059857408709a9e2de92acc47cfd..56f9e417112d54483e56c46e39802fa2f19ecb38 100644 (file)
@@ -13,7 +13,7 @@ from __future__ import annotations
 import os
 import pickle
 from collections.abc import Iterable, Mapping
-from typing import TYPE_CHECKING, Any, overload
+from typing import TYPE_CHECKING, Any
 
 from babel import localedata
 from babel.plural import PluralRule
@@ -260,21 +260,13 @@ class Locale:
         if identifier:
             return Locale.parse(identifier, sep=sep)
 
-    @overload
-    @classmethod
-    def parse(cls, identifier: None, sep: str = ..., resolve_likely_subtags: bool = ...) -> None: ...
-
-    @overload
-    @classmethod
-    def parse(cls, identifier: str | Locale, sep: str = ..., resolve_likely_subtags: bool = ...) -> Locale: ...
-
     @classmethod
     def parse(
         cls,
         identifier: str | Locale | None,
         sep: str = '_',
         resolve_likely_subtags: bool = True,
-    ) -> Locale | None:
+    ) -> Locale:
         """Create a `Locale` instance for the given locale identifier.
 
         >>> l = Locale.parse('de-DE', sep='-')
@@ -317,10 +309,9 @@ class Locale:
                              identifier
         :raise `UnknownLocaleError`: if no locale data is available for the
                                      requested locale
+        :raise `TypeError`: if the identifier is not a string or a `Locale`
         """
-        if identifier is None:
-            return None
-        elif isinstance(identifier, Locale):
+        if isinstance(identifier, Locale):
             return identifier
         elif not isinstance(identifier, str):
             raise TypeError(f"Unexpected value for identifier: {identifier!r}")
@@ -364,9 +355,9 @@ class Locale:
             language, territory, script, variant = parts
             modifier = None
         language = get_global('language_aliases').get(language, language)
-        territory = get_global('territory_aliases').get(territory, (territory,))[0]
-        script = get_global('script_aliases').get(script, script)
-        variant = get_global('variant_aliases').get(variant, variant)
+        territory = get_global('territory_aliases').get(territory or '', (territory,))[0]
+        script = get_global('script_aliases').get(script or '', script)
+        variant = get_global('variant_aliases').get(variant or '', variant)
 
         if territory == 'ZZ':
             territory = None
@@ -389,9 +380,9 @@ class Locale:
         if likely_subtag is not None:
             parts2 = parse_locale(likely_subtag)
             if len(parts2) == 5:
-                language2, _, script2, variant2, modifier2 = parse_locale(likely_subtag)
+                language2, _, script2, variant2, modifier2 = parts2
             else:
-                language2, _, script2, variant2 = parse_locale(likely_subtag)
+                language2, _, script2, variant2 = parts2
                 modifier2 = None
             locale = _try_load_reducing((language2, territory, script2, variant2, modifier2))
             if locale is not None:
@@ -1147,7 +1138,7 @@ def negotiate_locale(preferred: Iterable[str], available: Iterable[str], sep: st
 def parse_locale(
     identifier: str,
     sep: str = '_'
-) -> tuple[str, str | None, str | None, str | None, str | None]:
+) -> tuple[str, str | None, str | None, str | None] | tuple[str, str | None, str | None, str | None, str | None]:
     """Parse a locale identifier into a tuple of the form ``(language,
     territory, script, variant, modifier)``.
 
@@ -1261,7 +1252,7 @@ def get_locale_identifier(
     :param tup: the tuple as returned by :func:`parse_locale`.
     :param sep: the separator for the identifier.
     """
-    tup = tuple(tup[:5])
+    tup = tuple(tup[:5])  # type: ignore  # length should be no more than 5
     lang, territory, script, variant, modifier = tup + (None,) * (5 - len(tup))
     ret = sep.join(filter(None, (lang, script, territory, variant)))
     return f'{ret}@{modifier}' if modifier else ret
index 59acee21229308b646592b44e64f25657a39cc18..1a86d9e6250f49352f97407173065f0f44be6665 100644 (file)
@@ -23,7 +23,7 @@ import datetime
 import decimal
 import re
 import warnings
-from typing import TYPE_CHECKING, Any, overload
+from typing import TYPE_CHECKING, Any, cast, overload
 
 from babel.core import Locale, default_locale, get_global
 from babel.localedata import LocaleDataDict
@@ -428,7 +428,7 @@ def get_decimal_quantum(precision: int | decimal.Decimal) -> decimal.Decimal:
 
 def format_decimal(
     number: float | decimal.Decimal | str,
-    format: str | None = None,
+    format: str | NumberPattern | None = None,
     locale: Locale | str | None = LC_NUMERIC,
     decimal_quantization: bool = True,
     group_separator: bool = True,
@@ -474,8 +474,8 @@ def format_decimal(
                             number format.
     """
     locale = Locale.parse(locale)
-    if not format:
-        format = locale.decimal_formats.get(format)
+    if format is None:
+        format = locale.decimal_formats[format]
     pattern = parse_pattern(format)
     return pattern.apply(
         number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator)
@@ -513,7 +513,7 @@ def format_compact_decimal(
     number, format = _get_compact_format(number, compact_format, locale, fraction_digits)
     # Did not find a format, fall back.
     if format is None:
-        format = locale.decimal_formats.get(None)
+        format = locale.decimal_formats[None]
     pattern = parse_pattern(format)
     return pattern.apply(number, locale, decimal_quantization=False)
 
@@ -521,7 +521,7 @@ def format_compact_decimal(
 def _get_compact_format(
     number: float | decimal.Decimal | str,
     compact_format: LocaleDataDict,
-    locale: Locale | str | None,
+    locale: Locale,
     fraction_digits: int,
 ) -> tuple[decimal.Decimal, NumberPattern | None]:
     """Returns the number after dividing by the unit and the format pattern to use.
@@ -543,7 +543,7 @@ def _get_compact_format(
                 break
             # otherwise, we need to divide the number by the magnitude but remove zeros
             # equal to the number of 0's in the pattern minus 1
-            number = number / (magnitude // (10 ** (pattern.count("0") - 1)))
+            number = cast(decimal.Decimal, number / (magnitude // (10 ** (pattern.count("0") - 1))))
             # round to the number of fraction digits requested
             rounded = round(number, fraction_digits)
             # if the remaining number is singular, use the singular format
@@ -565,7 +565,7 @@ class UnknownCurrencyFormatError(KeyError):
 def format_currency(
     number: float | decimal.Decimal | str,
     currency: str,
-    format: str | None = None,
+    format: str | NumberPattern | None = None,
     locale: Locale | str | None = LC_NUMERIC,
     currency_digits: bool = True,
     format_type: Literal["name", "standard", "accounting"] = "standard",
@@ -680,7 +680,7 @@ def format_currency(
 def _format_currency_long_name(
     number: float | decimal.Decimal | str,
     currency: str,
-    format: str | None = None,
+    format: str | NumberPattern | None = None,
     locale: Locale | str | None = LC_NUMERIC,
     currency_digits: bool = True,
     format_type: Literal["name", "standard", "accounting"] = "standard",
@@ -706,7 +706,7 @@ def _format_currency_long_name(
 
     # Step 5.
     if not format:
-        format = locale.decimal_formats.get(format)
+        format = locale.decimal_formats[format]
 
     pattern = parse_pattern(format)
 
@@ -758,13 +758,15 @@ def format_compact_currency(
             # compress adjacent spaces into one
             format = re.sub(r'(\s)\s+', r'\1', format).strip()
             break
+    if format is None:
+        raise ValueError('No compact currency format found for the given number and locale.')
     pattern = parse_pattern(format)
     return pattern.apply(number, locale, currency=currency, currency_digits=False, decimal_quantization=False)
 
 
 def format_percent(
     number: float | decimal.Decimal | str,
-    format: str | None = None,
+    format: str | NumberPattern | None = None,
     locale: Locale | str | None = LC_NUMERIC,
     decimal_quantization: bool = True,
     group_separator: bool = True,
@@ -808,7 +810,7 @@ def format_percent(
     """
     locale = Locale.parse(locale)
     if not format:
-        format = locale.percent_formats.get(format)
+        format = locale.percent_formats[format]
     pattern = parse_pattern(format)
     return pattern.apply(
         number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator)
@@ -816,7 +818,7 @@ def format_percent(
 
 def format_scientific(
         number: float | decimal.Decimal | str,
-        format: str | None = None,
+        format: str | NumberPattern | None = None,
         locale: Locale | str | None = LC_NUMERIC,
         decimal_quantization: bool = True,
 ) -> str:
@@ -847,7 +849,7 @@ def format_scientific(
     """
     locale = Locale.parse(locale)
     if not format:
-        format = locale.scientific_formats.get(format)
+        format = locale.scientific_formats[format]
     pattern = parse_pattern(format)
     return pattern.apply(
         number, locale, decimal_quantization=decimal_quantization)
@@ -856,7 +858,7 @@ def format_scientific(
 class NumberFormatError(ValueError):
     """Exception raised when a string cannot be parsed into a number."""
 
-    def __init__(self, message: str, suggestions: str | None = None) -> None:
+    def __init__(self, message: str, suggestions: list[str] | None = None) -> None:
         super().__init__(message)
         #: a list of properly formatted numbers derived from the invalid input
         self.suggestions = suggestions
@@ -1140,7 +1142,7 @@ class NumberPattern:
 
     def apply(
         self,
-        value: float | decimal.Decimal,
+        value: float | decimal.Decimal | str,
         locale: Locale | str | None,
         currency: str | None = None,
         currency_digits: bool = True,
@@ -1211,9 +1213,9 @@ class NumberPattern:
             number = ''.join([
                 self._quantize_value(value, locale, frac_prec, group_separator),
                 get_exponential_symbol(locale),
-                exp_sign,
-                self._format_int(
-                    str(exp), self.exp_prec[0], self.exp_prec[1], locale)])
+                exp_sign,  # type: ignore  # exp_sign is always defined here
+                self._format_int(str(exp), self.exp_prec[0], self.exp_prec[1], locale)  # type: ignore  # exp is always defined here
+            ])
 
         # Is it a significant digits pattern?
         elif '@' in self.pattern:
@@ -1234,9 +1236,8 @@ class NumberPattern:
             number if self.number_pattern != '' else '',
             self.suffix[is_negative]])
 
-        if '¤' in retval:
-            retval = retval.replace('¤¤¤',
-                                    get_currency_name(currency, value, locale))
+        if '¤' in retval and currency is not None:
+            retval = retval.replace('¤¤¤', get_currency_name(currency, value, locale))
             retval = retval.replace('¤¤', currency.upper())
             retval = retval.replace('¤', get_currency_symbol(currency, locale))