]> git.ipfire.org Git - thirdparty/babel.git/commitdiff
Refactor decimal handling code.
authorKevin Deldycke <kevin@deldycke.com>
Fri, 7 Apr 2017 14:09:14 +0000 (16:09 +0200)
committerIsaac Jurado <diptongo@gmail.com>
Tue, 17 Oct 2017 22:04:25 +0000 (00:04 +0200)
babel/numbers.py
tests/test_numbers.py

index 8728699fb23b4f514f838b2a5950d525addf4585..036513217f730d71c51cb84d4203434c76d01b28 100644 (file)
 import re
 from datetime import date as date_, datetime as datetime_
 from itertools import chain
+import warnings
+from itertools import chain
 
 from babel.core import default_locale, Locale, get_global
 from babel._compat import decimal, string_types
 from babel.localedata import locale_identifiers
 
+try:
+    long
+except NameError:
+    long = int
+
 
 LC_NUMERIC = default_locale('LC_NUMERIC')
 
@@ -304,14 +311,25 @@ def format_number(number, locale=LC_NUMERIC):
     >>> format_number(1099, locale='de_DE')
     u'1.099'
 
+    .. deprecated:: 2.6.0
+
+       Use babel.numbers.format_decimal() instead.
 
     :param number: the number to format
     :param locale: the `Locale` object or locale identifier
+
+
     """
-    # Do we really need this one?
+    warnings.warn('Use babel.numbers.format_decimal() instead.', DeprecationWarning)
     return format_decimal(number, locale=locale)
 
 
+def get_decimal_quantum(precision):
+    """Return minimal quantum of a number, as defined by precision."""
+    assert isinstance(precision, (int, long, decimal.Decimal))
+    return decimal.Decimal(10) ** (-precision)
+
+
 def format_decimal(number, format=None, locale=LC_NUMERIC):
     u"""Return the given decimal number formatted for a specific locale.
 
@@ -412,14 +430,11 @@ def format_currency(number, currency, format=None, locale=LC_NUMERIC,
         try:
             pattern = locale.currency_formats[format_type]
         except KeyError:
-            raise UnknownCurrencyFormatError("%r is not a known currency format"
-                                             " type" % format_type)
-    if currency_digits:
-        precision = get_currency_precision(currency)
-        frac = (precision, precision)
-    else:
-        frac = None
-    return pattern.apply(number, locale, currency=currency, force_frac=frac)
+            raise UnknownCurrencyFormatError(
+                "%r is not a known currency format type" % format_type)
+
+    return pattern.apply(
+        number, locale, currency=currency, currency_digits=currency_digits)
 
 
 def format_percent(number, format=None, locale=LC_NUMERIC):
@@ -456,7 +471,7 @@ def format_scientific(number, format=None, locale=LC_NUMERIC):
 
     The format pattern can also be specified explicitly:
 
-    >>> format_scientific(1234567, u'##0E00', locale='en_US')
+    >>> format_scientific(1234567, u'##0.##E00', locale='en_US')
     u'1.23E06'
 
     :param number: the number to format
@@ -615,7 +630,6 @@ def parse_pattern(pattern):
     int_prec = parse_precision(integer)
     frac_prec = parse_precision(fraction)
     if exp:
-        frac_prec = parse_precision(integer + fraction)
         exp_plus = exp.startswith('+')
         exp = exp.lstrip('+')
         exp_prec = parse_precision(exp)
@@ -633,6 +647,7 @@ class NumberPattern(object):
 
     def __init__(self, pattern, prefix, suffix, grouping,
                  int_prec, frac_prec, exp_prec, exp_plus):
+        # Metadata of the decomposed parsed pattern.
         self.pattern = pattern
         self.prefix = prefix
         self.suffix = suffix
@@ -641,68 +656,108 @@ class NumberPattern(object):
         self.frac_prec = frac_prec
         self.exp_prec = exp_prec
         self.exp_plus = exp_plus
-        if '%' in ''.join(self.prefix + self.suffix):
-            self.scale = 2
-        elif u'‰' in ''.join(self.prefix + self.suffix):
-            self.scale = 3
-        else:
-            self.scale = 0
+        self.scale = self.compute_scale()
 
     def __repr__(self):
         return '<%s %r>' % (type(self).__name__, self.pattern)
 
-    def apply(self, value, locale, currency=None, force_frac=None):
-        frac_prec = force_frac or self.frac_prec
+    def compute_scale(self):
+        """Return the scaling factor to apply to the number before rendering.
+
+        Auto-set to a factor of 2 or 3 if presence of a ``%`` or ``‰`` sign is
+        detected in the prefix or suffix of the pattern. Default is to not mess
+        with the scale at all and keep it to 0.
+        """
+        scale = 0
+        if '%' in ''.join(self.prefix + self.suffix):
+            scale = 2
+        elif u'‰' in ''.join(self.prefix + self.suffix):
+            scale = 3
+        return scale
+
+    def scientific_notation_elements(self, value, locale):
+        """ Returns normalized scientific notation components of a value.
+        """
+        # Normalize value to only have one lead digit.
+        exp = value.adjusted()
+        value = value * get_decimal_quantum(exp)
+        assert value.adjusted() == 0
+
+        # Shift exponent and value by the minimum number of leading digits
+        # imposed by the rendering pattern. And always make that number
+        # greater or equal to 1.
+        lead_shift = max([1, min(self.int_prec)]) - 1
+        exp = exp - lead_shift
+        value = value * get_decimal_quantum(-lead_shift)
+
+        # Get exponent sign symbol.
+        exp_sign = ''
+        if exp < 0:
+            exp_sign = get_minus_sign_symbol(locale)
+        elif self.exp_plus:
+            exp_sign = get_plus_sign_symbol(locale)
+
+        # Normalize exponent value now that we have the sign.
+        exp = abs(exp)
+
+        return value, exp, exp_sign
+
+    def apply(self, value, locale, currency=None, currency_digits=True):
+        """Renders into a string a number following the defined pattern.
+        """
         if not isinstance(value, decimal.Decimal):
             value = decimal.Decimal(str(value))
+
         value = value.scaleb(self.scale)
+
+        # Separate the absolute value from its sign.
         is_negative = int(value.is_signed())
-        if self.exp_prec:  # Scientific notation
-            exp = value.adjusted()
-            value = abs(value)
-            # Minimum number of integer digits
-            if self.int_prec[0] == self.int_prec[1]:
-                exp -= self.int_prec[0] - 1
-            # Exponent grouping
-            elif self.int_prec[1]:
-                exp = int(exp / self.int_prec[1]) * self.int_prec[1]
-            if exp < 0:
-                value = value * 10**(-exp)
-            else:
-                value = value / 10**exp
-            exp_sign = ''
-            if exp < 0:
-                exp_sign = get_minus_sign_symbol(locale)
-            elif self.exp_plus:
-                exp_sign = get_plus_sign_symbol(locale)
-            exp = abs(exp)
-            number = u'%s%s%s%s' % \
-                (self._format_significant(value, frac_prec[0], frac_prec[1]),
-                 get_exponential_symbol(locale), exp_sign,
-                 self._format_int(str(exp), self.exp_prec[0],
-                                  self.exp_prec[1], locale))
-        elif '@' in self.pattern:  # Is it a siginificant digits pattern?
-            text = self._format_significant(abs(value),
+        value = abs(value).normalize()
+
+        # Prepare scientific notation metadata.
+        if self.exp_prec:
+            value, exp, exp_sign = self.scientific_notation_elements(value, locale)
+
+        # Adjust the precision of the fractionnal part and force it to the
+        # currency's if neccessary.
+        frac_prec = self.frac_prec
+        if currency and currency_digits:
+            frac_prec = (get_currency_precision(currency), ) * 2
+
+        # Render scientific notation.
+        if self.exp_prec:
+            number = ''.join([
+                self._quantize_value(value, locale, frac_prec),
+                get_exponential_symbol(locale),
+                exp_sign,
+                self._format_int(
+                    str(exp), self.exp_prec[0], self.exp_prec[1], locale)])
+
+        # Is it a siginificant digits pattern?
+        elif '@' in self.pattern:
+            text = self._format_significant(value,
                                             self.int_prec[0],
                                             self.int_prec[1])
             a, sep, b = text.partition(".")
             number = self._format_int(a, 0, 1000, locale)
             if sep:
                 number += get_decimal_symbol(locale) + b
-        else:  # A normal number pattern
-            precision = decimal.Decimal('1.' + '1' * frac_prec[1])
-            rounded = value.quantize(precision)
-            a, sep, b = str(abs(rounded)).partition(".")
-            number = (self._format_int(a, self.int_prec[0],
-                                       self.int_prec[1], locale) +
-                      self._format_frac(b or '0', locale, force_frac))
-        retval = u'%s%s%s' % (self.prefix[is_negative], number,
-                              self.suffix[is_negative])
+
+        # A normal number pattern.
+        else:
+            number = self._quantize_value(value, locale, frac_prec)
+
+        retval = ''.join([
+            self.prefix[is_negative],
+            number,
+            self.suffix[is_negative]])
+
         if u'¤' in retval:
             retval = retval.replace(u'¤¤¤',
                                     get_currency_name(currency, value, locale))
             retval = retval.replace(u'¤¤', currency.upper())
             retval = retval.replace(u'¤', get_currency_symbol(currency, locale))
+
         return retval
 
     #
@@ -757,6 +812,15 @@ class NumberPattern(object):
             gsize = self.grouping[1]
         return value + ret
 
+    def _quantize_value(self, value, locale, frac_prec):
+        quantum = get_decimal_quantum(frac_prec[1])
+        rounded = value.quantize(quantum)
+        a, sep, b = str(rounded).partition(".")
+        number = (self._format_int(a, self.int_prec[0],
+                                   self.int_prec[1], locale) +
+                  self._format_frac(b or '0', locale, frac_prec))
+        return number
+
     def _format_frac(self, value, locale, force_frac=None):
         min, max = force_frac or self.frac_prec
         if len(value) < min:
index 5bcd1717d1b6cfd40cb5a85aafb26a034a9a6847..5c8da3422d0f0339f5dbb22c1dceaf79118d9a3a 100644 (file)
@@ -124,7 +124,7 @@ class FormatDecimalTestCase(unittest.TestCase):
         self.assertEqual(fmt, '1.2E3')
         # Exponent grouping
         fmt = numbers.format_scientific(12345, '##0.####E0', locale='en_US')
-        self.assertEqual(fmt, '12.345E3')
+        self.assertEqual(fmt, '1.2345E4')
         # Minimum number of int digits
         fmt = numbers.format_scientific(12345, '00.###E0', locale='en_US')
         self.assertEqual(fmt, '12.345E3')
@@ -283,11 +283,45 @@ def test_format_decimal():
     assert numbers.format_decimal(1.2345, locale='sv_SE') == u'1,234'
     assert numbers.format_decimal(1.2345, locale='de') == u'1,234'
     assert numbers.format_decimal(12345.5, locale='en_US') == u'12,345.5'
+    assert numbers.format_decimal(0001.2345000, locale='en_US') == u'1.234'
+    assert numbers.format_decimal(-0001.2346000, locale='en_US') == u'-1.235'
+    assert numbers.format_decimal(0000000.5, locale='en_US') == u'0.5'
+    assert numbers.format_decimal(000, locale='en_US') == u'0'
+
+
+@pytest.mark.parametrize('input_value, expected_value', [
+    ('10000', '10,000'),
+    ('1', '1'),
+    ('1.0', '1'),
+    ('1.1', '1.1'),
+    ('1.11', '1.11'),
+    ('1.110', '1.11'),
+    ('1.001', '1.001'),
+    ('1.00100', '1.001'),
+    ('01.00100', '1.001'),
+    ('101.00100', '101.001'),
+    ('00000', '0'),
+    ('0', '0'),
+    ('0.0', '0'),
+    ('0.1', '0.1'),
+    ('0.11', '0.11'),
+    ('0.110', '0.11'),
+    ('0.001', '0.001'),
+    ('0.00100', '0.001'),
+    ('00.00100', '0.001'),
+    ('000.00100', '0.001'),
+])
+def test_format_decimal_precision(input_value, expected_value):
+    # Test precision conservation.
+    assert numbers.format_decimal(
+        decimal.Decimal(input_value), locale='en_US') == expected_value
 
 
 def test_format_currency():
     assert (numbers.format_currency(1099.98, 'USD', locale='en_US')
             == u'$1,099.98')
+    assert (numbers.format_currency(0, 'USD', locale='en_US')
+            == u'$0.00')
     assert (numbers.format_currency(1099.98, 'USD', locale='es_CO')
             == u'US$\xa01.099,98')
     assert (numbers.format_currency(1099.98, 'EUR', locale='de_DE')
@@ -306,10 +340,16 @@ def test_format_currency_format_type():
     assert (numbers.format_currency(1099.98, 'USD', locale='en_US',
                                     format_type="standard")
             == u'$1,099.98')
+    assert (numbers.format_currency(0, 'USD', locale='en_US',
+                                    format_type="standard")
+            == u'$0.00')
 
     assert (numbers.format_currency(1099.98, 'USD', locale='en_US',
                                     format_type="accounting")
             == u'$1,099.98')
+    assert (numbers.format_currency(0, 'USD', locale='en_US',
+                                    format_type="accounting")
+            == u'$0.00')
 
     with pytest.raises(numbers.UnknownCurrencyFormatError) as excinfo:
         numbers.format_currency(1099.98, 'USD', locale='en_US',
@@ -328,8 +368,37 @@ def test_format_currency_format_type():
             == u'1.099,98')
 
 
+@pytest.mark.parametrize('input_value, expected_value', [
+    ('10000', '$10,000.00'),
+    ('1', '$1.00'),
+    ('1.0', '$1.00'),
+    ('1.1', '$1.10'),
+    ('1.11', '$1.11'),
+    ('1.110', '$1.11'),
+    ('1.001', '$1.00'),
+    ('1.00100', '$1.00'),
+    ('01.00100', '$1.00'),
+    ('101.00100', '$101.00'),
+    ('00000', '$0.00'),
+    ('0', '$0.00'),
+    ('0.0', '$0.00'),
+    ('0.1', '$0.10'),
+    ('0.11', '$0.11'),
+    ('0.110', '$0.11'),
+    ('0.001', '$0.00'),
+    ('0.00100', '$0.00'),
+    ('00.00100', '$0.00'),
+    ('000.00100', '$0.00'),
+])
+def test_format_currency_precision(input_value, expected_value):
+    # Test precision conservation.
+    assert numbers.format_currency(
+        decimal.Decimal(input_value), 'USD', locale='en_US') == expected_value
+
+
 def test_format_percent():
     assert numbers.format_percent(0.34, locale='en_US') == u'34%'
+    assert numbers.format_percent(0, locale='en_US') == u'0%'
     assert numbers.format_percent(0.34, u'##0%', locale='en_US') == u'34%'
     assert numbers.format_percent(34, u'##0', locale='en_US') == u'34'
     assert numbers.format_percent(25.1234, locale='en_US') == u'2,512%'
@@ -339,14 +408,81 @@ def test_format_percent():
             == u'25,123\u2030')
 
 
-def test_scientific_exponent_displayed_as_integer():
-    assert numbers.format_scientific(100000, locale='en_US') == u'1E5'
+@pytest.mark.parametrize('input_value, expected_value', [
+    ('100', '10,000%'),
+    ('0.01', '1%'),
+    ('0.010', '1%'),
+    ('0.011', '1%'),
+    ('0.0111', '1%'),
+    ('0.01110', '1%'),
+    ('0.01001', '1%'),
+    ('0.0100100', '1%'),
+    ('0.010100100', '1%'),
+    ('0.000000', '0%'),
+    ('0', '0%'),
+    ('0.00', '0%'),
+    ('0.01', '1%'),
+    ('0.011', '1%'),
+    ('0.0110', '1%'),
+    ('0.0001', '0%'),
+    ('0.000100', '0%'),
+    ('0.0000100', '0%'),
+    ('0.00000100', '0%'),
+])
+def test_format_percent_precision(input_value, expected_value):
+    # Test precision conservation.
+    assert numbers.format_percent(
+        decimal.Decimal(input_value), locale='en_US') == expected_value
 
 
 def test_format_scientific():
     assert numbers.format_scientific(10000, locale='en_US') == u'1E4'
-    assert (numbers.format_scientific(1234567, u'##0E00', locale='en_US')
-            == u'1.23E06')
+    assert numbers.format_scientific(4234567, u'#.#E0', locale='en_US') == u'4.2E6'
+    assert numbers.format_scientific(4234567, u'0E0000', locale='en_US') == u'4E0006'
+    assert numbers.format_scientific(4234567, u'##0E00', locale='en_US') == u'4E06'
+    assert numbers.format_scientific(4234567, u'##00E00', locale='en_US') == u'42E05'
+    assert numbers.format_scientific(4234567, u'0,000E00', locale='en_US') == u'4,235E03'
+    assert numbers.format_scientific(4234567, u'##0.#####E00', locale='en_US') == u'4.23457E06'
+    assert numbers.format_scientific(4234567, u'##0.##E00', locale='en_US') == u'4.23E06'
+    assert numbers.format_scientific(42, u'00000.000000E0000', locale='en_US') == u'42000.000000E-0003'
+
+
+def test_default_scientific_format():
+    """ Check the scientific format method auto-correct the rendering pattern
+    in case of a missing fractional part.
+    """
+    assert numbers.format_scientific(12345, locale='en_US') == u'1E4'
+    assert numbers.format_scientific(12345.678, locale='en_US') == u'1E4'
+    assert numbers.format_scientific(12345, u'#E0', locale='en_US') == u'1E4'
+    assert numbers.format_scientific(12345.678, u'#E0', locale='en_US') == u'1E4'
+
+
+@pytest.mark.parametrize('input_value, expected_value', [
+    ('10000', '1E4'),
+    ('1', '1E0'),
+    ('1.0', '1E0'),
+    ('1.1', '1E0'),
+    ('1.11', '1E0'),
+    ('1.110', '1E0'),
+    ('1.001', '1E0'),
+    ('1.00100', '1E0'),
+    ('01.00100', '1E0'),
+    ('101.00100', '1E2'),
+    ('00000', '0E0'),
+    ('0', '0E0'),
+    ('0.0', '0E0'),
+    ('0.1', '1E-1'),
+    ('0.11', '1E-1'),
+    ('0.110', '1E-1'),
+    ('0.001', '1E-3'),
+    ('0.00100', '1E-3'),
+    ('00.00100', '1E-3'),
+    ('000.00100', '1E-3'),
+])
+def test_format_scientific_precision(input_value, expected_value):
+    # Test precision conservation.
+    assert numbers.format_scientific(
+        decimal.Decimal(input_value), locale='en_US') == expected_value
 
 
 def test_parse_number():