]> git.ipfire.org Git - thirdparty/babel.git/commitdiff
numbers: Implement rounding with Decimal
authorIsaac Jurado <diptongo@gmail.com>
Sun, 4 Oct 2015 18:36:02 +0000 (20:36 +0200)
committerIsaac Jurado <diptongo@gmail.com>
Wed, 14 Oct 2015 17:52:38 +0000 (19:52 +0200)
Drop the old bankersround related code and implement rounding using the decimal
module instead.  This change will enable some other goodies such as: use the
drop-in replacement cdecimal when available, or allow for more rounding
algorithms by exposing one more parameter.

babel/numbers.py
tests/test_numbers.py

index af9413f5742e135ee11432e2c28e2c72c0f023c9..f92c714caed8f5777be292dfe70bb17f310b78e8 100644 (file)
 # TODO:
 #  Padding and rounding increments in pattern:
 #  - http://www.unicode.org/reports/tr35/ (Appendix G.6)
-from decimal import Decimal, InvalidOperation
-import math
 import re
 from datetime import date as date_, datetime as datetime_
+from decimal import Decimal, InvalidOperation, ROUND_HALF_EVEN
 
 from babel.core import default_locale, Locale, get_global
 from babel._compat import range_type
@@ -455,94 +454,6 @@ SUFFIX_PATTERN = r"(?P<suffix>.*)"
 number_re = re.compile(r"%s%s%s" % (PREFIX_PATTERN, NUMBER_PATTERN,
                                     SUFFIX_PATTERN))
 
-def split_number(value):
-    """Convert a number into a (intasstring, fractionasstring) tuple"""
-    if isinstance(value, Decimal):
-        # NB can't just do text = str(value) as str repr of Decimal may be
-        # in scientific notation, e.g. for small numbers.
-
-        sign, digits, exp = value.as_tuple()
-        # build list of digits in reverse order, then reverse+join
-        # as per http://docs.python.org/library/decimal.html#recipes
-        int_part = []
-        frac_part = []
-
-        digits = list(map(str, digits))
-
-        # get figures after decimal point
-        for i in range(-exp):
-            # add digit if available, else 0
-            if digits:
-                frac_part.append(digits.pop())
-            else:
-                frac_part.append('0')
-
-        # add in some zeroes...
-        for i in range(exp):
-            int_part.append('0')
-
-        # and the rest
-        while digits:
-            int_part.append(digits.pop())
-
-        # if < 1, int_part must be set to '0'
-        if len(int_part) == 0:
-            int_part = '0',
-
-        if sign:
-            int_part.append('-')
-
-        return ''.join(reversed(int_part)), ''.join(reversed(frac_part))
-    text = ('%.9f' % value).rstrip('0')
-    if '.' in text:
-        a, b = text.split('.', 1)
-        if b == '0':
-            b = ''
-    else:
-        a, b = text, ''
-    return a, b
-
-
-def bankersround(value, ndigits=0):
-    """Round a number to a given precision.
-
-    Works like round() except that the round-half-even (banker's rounding)
-    algorithm is used instead of round-half-up.
-
-    >>> bankersround(5.5, 0)
-    6.0
-    >>> bankersround(6.5, 0)
-    6.0
-    >>> bankersround(-6.5, 0)
-    -6.0
-    >>> bankersround(1234.0, -2)
-    1200.0
-    """
-    sign = int(value < 0) and -1 or 1
-    value = abs(value)
-    a, b = split_number(value)
-    digits = a + b
-    add = 0
-    i = len(a) + ndigits
-    if i < 0 or i >= len(digits):
-        pass
-    elif digits[i] > '5':
-        add = 1
-    elif digits[i] == '5' and digits[i-1] in '13579':
-        add = 1
-    elif digits[i] == '5':     # previous digit is even
-        # We round up unless all following digits are zero.
-        for j in range_type(i + 1, len(digits)):
-            if digits[j] != '0':
-                add = 1
-                break
-
-    scale = 10**ndigits
-    if isinstance(value, Decimal):
-        return Decimal(int(value * scale + add)) / scale * sign
-    else:
-        return float(int(value * scale + add)) / scale * sign
-
 
 def parse_grouping(p):
     """Parse primary and secondary digit grouping
@@ -645,35 +556,30 @@ class NumberPattern(object):
         self.exp_prec = exp_prec
         self.exp_plus = exp_plus
         if '%' in ''.join(self.prefix + self.suffix):
-            self.scale = 100
+            self.scale = 2
         elif u'‰' in ''.join(self.prefix + self.suffix):
-            self.scale = 1000
+            self.scale = 3
         else:
-            self.scale = 1
+            self.scale = 0
 
     def __repr__(self):
         return '<%s %r>' % (type(self).__name__, self.pattern)
 
     def apply(self, value, locale, currency=None, force_frac=None):
         frac_prec = force_frac or self.frac_prec
-        if isinstance(value, float):
+        if not isinstance(value, Decimal):
             value = Decimal(str(value))
-        value *= self.scale
-        is_negative = int(value < 0)
+        value = value.scaleb(self.scale)
+        is_negative = int(value.is_signed())
         if self.exp_prec: # Scientific notation
+            exp = value.adjusted()
             value = abs(value)
-            if value:
-                exp = int(math.floor(math.log(value, 10)))
-            else:
-                exp = 0
             # Minimum number of integer digits
             if self.int_prec[0] == self.int_prec[1]:
                 exp -= self.int_prec[0] - 1
             # Exponent grouping
             elif self.int_prec[1]:
                 exp = int(exp / self.int_prec[1]) * self.int_prec[1]
-            if not isinstance(value, Decimal):
-                value = float(value)
             if exp < 0:
                 value = value * 10**(-exp)
             else:
@@ -685,29 +591,25 @@ class NumberPattern(object):
                 exp_sign = get_plus_sign_symbol(locale)
             exp = abs(exp)
             number = u'%s%s%s%s' % \
-                 (self._format_sigdig(value, frac_prec[0], frac_prec[1]),
+                 (self._format_significant(value, frac_prec[0], frac_prec[1]),
                   get_exponential_symbol(locale),  exp_sign,
                   self._format_int(str(exp), self.exp_prec[0],
                                    self.exp_prec[1], locale))
         elif '@' in self.pattern: # Is it a siginificant digits pattern?
-            text = self._format_sigdig(abs(value),
-                                      self.int_prec[0],
-                                      self.int_prec[1])
-            if '.' in text:
-                a, b = text.split('.')
-                a = self._format_int(a, 0, 1000, locale)
-                if b:
-                    b = get_decimal_symbol(locale) + b
-                number = a + b
-            else:
-                number = self._format_int(text, 0, 1000, locale)
+            text = self._format_significant(abs(value),
+                                            self.int_prec[0],
+                                            self.int_prec[1])
+            a, sep, b = text.partition(".")
+            number = self._format_int(a, 0, 1000, locale)
+            if sep:
+                number += get_decimal_symbol(locale) + b
         else: # A normal number pattern
-            a, b = split_number(bankersround(abs(value), frac_prec[1]))
-            b = b or '0'
-            a = self._format_int(a, self.int_prec[0],
-                                 self.int_prec[1], locale)
-            b = self._format_frac(b, locale, force_frac)
-            number = a + b
+            precision = Decimal('1.' + '1' * frac_prec[1])
+            rounded = value.quantize(precision, ROUND_HALF_EVEN)
+            a, sep, b = str(abs(rounded)).partition(".")
+            number = (self._format_int(a, self.int_prec[0],
+                                       self.int_prec[1], locale) +
+                      self._format_frac(b or '0', locale, force_frac))
         retval = u'%s%s%s' % (self.prefix[is_negative], number,
                                 self.suffix[is_negative])
         if u'¤' in retval:
@@ -717,31 +619,44 @@ class NumberPattern(object):
             retval = retval.replace(u'¤', get_currency_symbol(currency, locale))
         return retval
 
-    def _format_sigdig(self, value, min, max):
-        """Convert value to a string.
-
-        The resulting string will contain between (min, max) number of
-        significant digits.
-        """
-        a, b = split_number(value)
-        ndecimals = len(a)
-        if a == '0' and b != '':
-            ndecimals = 0
-            while b.startswith('0'):
-                b = b[1:]
-                ndecimals -= 1
-        a, b = split_number(bankersround(value, max - ndecimals))
-        digits = len((a + b).lstrip('0'))
-        if not digits:
-            digits = 1
-        # Figure out if we need to add any trailing '0':s
-        if len(a) >= max and a != '0':
-            return a
-        if digits < min:
-            b += ('0' * (min - digits))
-        if b:
-            return '%s.%s' % (a, b)
-        return a
+    #
+    # This is one tricky piece of code.  The idea is to rely as much as possible
+    # on the decimal module to minimize the amount of code.
+    #
+    # Conceptually, the implementation of this method can be summarized in the
+    # following steps:
+    #
+    #   - Move or shift the decimal point (i.e. the exponent) so the maximum
+    #     amount of significant digits fall into the integer part (i.e. to the
+    #     left of the decimal point)
+    #
+    #   - Round the number to the nearest integer, discarding all the fractional
+    #     part which contained extra digits to be eliminated
+    #
+    #   - Convert the rounded integer to a string, that will contain the final
+    #     sequence of significant digits already trimmed to the maximum
+    #
+    #   - Restore the original position of the decimal point, potentially
+    #     padding with zeroes on either side
+    #
+    def _format_significant(self, value, minimum, maximum):
+        exp = value.adjusted()
+        scale = maximum - 1 - exp
+        digits = str(value.scaleb(scale).quantize(Decimal(1), ROUND_HALF_EVEN))
+        if scale <= 0:
+            result = digits + '0' * -scale
+        else:
+            intpart = digits[:-scale]
+            i = len(intpart)
+            j = i + max(minimum - i, 0)
+            result = "{intpart}.{pad:0<{fill}}{fracpart}{fracextra}".format(
+                    intpart=intpart or '0',
+                    pad='',
+                    fill=-min(exp + 1, 0),
+                    fracpart=digits[i:j],
+                    fracextra=digits[j:].rstrip('0'),
+            ).rstrip('.')
+        return result
 
     def _format_int(self, value, min, max, locale):
         width = len(value)
index a773f48f85dcf72930035d64c4b25886dfd4441b..fd3e7c815cc26add6de0ea578961baa1bb1abe46 100644 (file)
@@ -151,19 +151,6 @@ class FormatDecimalTestCase(unittest.TestCase):
         self.assertEqual('0.000000700', fmt)
 
 
-class BankersRoundTestCase(unittest.TestCase):
-    def test_round_to_nearest_integer(self):
-        self.assertEqual(1, numbers.bankersround(Decimal('0.5001')))
-
-    def test_round_to_even_for_two_nearest_integers(self):
-        self.assertEqual(0, numbers.bankersround(Decimal('0.5')))
-        self.assertEqual(2, numbers.bankersround(Decimal('1.5')))
-        self.assertEqual(-2, numbers.bankersround(Decimal('-2.5')))
-
-        self.assertEqual(0, numbers.bankersround(Decimal('0.05'), ndigits=1))
-        self.assertEqual(Decimal('0.2'), numbers.bankersround(Decimal('0.15'), ndigits=1))
-
-
 class NumberParsingTestCase(unittest.TestCase):
     def test_can_parse_decimals(self):
         self.assertEqual(Decimal('1099.98'),
@@ -320,13 +307,6 @@ def test_parse_decimal():
     assert excinfo.value.args[0] == "'2,109,998' is not a valid decimal number"
 
 
-def test_bankersround():
-    assert numbers.bankersround(5.5, 0) == 6.0
-    assert numbers.bankersround(6.5, 0) == 6.0
-    assert numbers.bankersround(-6.5, 0) == -6.0
-    assert numbers.bankersround(1234.0, -2) == 1200.0
-
-
 def test_parse_grouping():
     assert numbers.parse_grouping('##') == (1000, 1000)
     assert numbers.parse_grouping('#,###') == (3, 3)