]> git.ipfire.org Git - thirdparty/babel.git/commitdiff
add a `strict` mode to `parse_decimal()`
authorCharly C <changaco@changaco.oy.lc>
Sat, 16 Jun 2018 09:09:13 +0000 (11:09 +0200)
committerChangaco <changaco@changaco.oy.lc>
Sun, 17 Jun 2018 07:10:32 +0000 (09:10 +0200)
Fixes https://github.com/python-babel/babel/issues/589

babel/numbers.py
tests/test_numbers.py

index 518d4945e2a251f278877dcdf6505fa7b69ff8c3..509150cb9cb230f4ab456c66a3e9993794457a36 100644 (file)
@@ -661,7 +661,7 @@ def parse_number(string, locale=LC_NUMERIC):
         raise NumberFormatError('%r is not a valid number' % string)
 
 
-def parse_decimal(string, locale=LC_NUMERIC):
+def parse_decimal(string, locale=LC_NUMERIC, strict=False):
     """Parse localized decimal string into a decimal.
 
     >>> parse_decimal('1,099.98', locale='en_US')
@@ -676,17 +676,36 @@ def parse_decimal(string, locale=LC_NUMERIC):
         ...
     NumberFormatError: '2,109,998' is not a valid decimal number
 
+    If `strict` is set to `True` and the given string contains a number
+    formatted in an irregular way, an exception is raised:
+
+    >>> parse_decimal('30.00', locale='de', strict=True)
+    Traceback (most recent call last):
+        ...
+    NumberFormatError: '30.00' is not a properly formatted decimal number
+
     :param string: the string to parse
     :param locale: the `Locale` object or locale identifier
+    :param strict: controls whether numbers formatted in a weird way are
+                   accepted or rejected
     :raise NumberFormatError: if the string can not be converted to a
                               decimal number
     """
     locale = Locale.parse(locale)
+    group_symbol = get_group_symbol(locale)
+    decimal_symbol = get_decimal_symbol(locale)
     try:
-        return decimal.Decimal(string.replace(get_group_symbol(locale), '')
-                               .replace(get_decimal_symbol(locale), '.'))
+        parsed = decimal.Decimal(string.replace(group_symbol, '')
+                                       .replace(decimal_symbol, '.'))
     except decimal.InvalidOperation:
         raise NumberFormatError('%r is not a valid decimal number' % string)
+    if strict and group_symbol in string:
+        proper = format_decimal(parsed, locale=locale, decimal_quantization=False)
+        if string != proper and string.rstrip('0') != (proper + decimal_symbol):
+            raise NumberFormatError(
+                "%r is not a properly formatted decimal number" % string
+            )
+    return parsed
 
 
 PREFIX_END = r'[^0-9@#.,]'
index 32f4280e76eea6b7a5723cc5ecae199f7d985a62..50c53dec4470308d96ca8ad66d3584a457821520 100644 (file)
@@ -165,6 +165,25 @@ class NumberParsingTestCase(unittest.TestCase):
         self.assertRaises(numbers.NumberFormatError,
                           lambda: numbers.parse_decimal('2,109,998', locale='de'))
 
+    def test_parse_decimal_strict_mode(self):
+        # Numbers with a misplaced grouping symbol should be rejected
+        with self.assertRaises(numbers.NumberFormatError):
+            numbers.parse_decimal('11.11', locale='de', strict=True)
+        # Partially grouped numbers should be rejected
+        with self.assertRaises(numbers.NumberFormatError):
+            numbers.parse_decimal('2000,000', locale='en_US', strict=True)
+        # Numbers with duplicate grouping symbols should be rejected
+        with self.assertRaises(numbers.NumberFormatError):
+            numbers.parse_decimal('0,,000', locale='en_US', strict=True)
+        # Properly formatted numbers should be accepted
+        assert str(numbers.parse_decimal('1.001', locale='de', strict=True)) == '1001'
+        # Trailing zeroes should be accepted
+        assert str(numbers.parse_decimal('3.00', locale='en_US', strict=True)) == '3.00'
+        # Numbers without any grouping symbol should be accepted
+        assert str(numbers.parse_decimal('2000.1', locale='en_US', strict=True)) == '2000.1'
+        # High precision numbers should be accepted
+        assert str(numbers.parse_decimal('5,000001', locale='fr', strict=True)) == '5.000001'
+
 
 def test_list_currencies():
     assert isinstance(list_currencies(), set)