]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.12] gh-114563: C decimal falls back to pydecimal for unsupported format strings...
authorJohn Belmonte <john@neggie.net>
Mon, 12 Feb 2024 21:31:12 +0000 (06:31 +0900)
committerGitHub <noreply@github.com>
Mon, 12 Feb 2024 21:31:12 +0000 (23:31 +0200)
Immediate merits:
* eliminate complex workarounds for 'z' format support
  (NOTE: mpdecimal recently added 'z' support, so this becomes
  efficient in the long term.)
* fix 'z' format memory leak
* fix 'z' format applied to 'F'
* fix missing '#' format support

Suggested and prototyped by Stefan Krah.

Fixes gh-114563, gh-91060

(cherry picked from commit 72340d15cdfdfa4796fdd7c702094c852c2b32d2)

Co-authored-by: John Belmonte <john@neggie.net>
Co-authored-by: Stefan Krah <skrah@bytereef.org>
Lib/test/test_decimal.py
Misc/NEWS.d/next/Library/2024-02-11-20-23-36.gh-issue-114563.RzxNYT.rst [new file with mode: 0644]
Modules/_decimal/_decimal.c
Tools/c-analyzer/cpython/globals-to-fix.tsv

index 4d3ea732212905f47c7a2ccbb428022a28c17007..7ce61fc1edfd5d0c88e707198def710bcb25182a 100644 (file)
@@ -1121,6 +1121,13 @@ class FormatTest:
             ('z>z6.1f', '-0.', 'zzz0.0'),
             ('x>z6.1f', '-0.', 'xxx0.0'),
             ('🖤>z6.1f', '-0.', '🖤🖤🖤0.0'),  # multi-byte fill char
+            ('\x00>z6.1f', '-0.', '\x00\x00\x000.0'),  # null fill char
+
+            # issue 114563 ('z' format on F type in cdecimal)
+            ('z3,.10F', '-6.24E-323', '0.0000000000'),
+
+            # issue 91060 ('#' format in cdecimal)
+            ('#', '0', '0.'),
 
             # issue 6850
             ('a=-7.0', '0.12345', 'aaaa0.1'),
@@ -5712,6 +5719,21 @@ class CWhitebox(unittest.TestCase):
         with self.assertRaisesRegex(ValueError, err_msg):
             sd.copy()
 
+    def test_format_fallback_capitals(self):
+        # Fallback to _pydecimal formatting (triggered by `#` format which
+        # is unsupported by mpdecimal) should honor the current context.
+        x = C.Decimal('6.09e+23')
+        self.assertEqual(format(x, '#'), '6.09E+23')
+        with C.localcontext(capitals=0):
+            self.assertEqual(format(x, '#'), '6.09e+23')
+
+    def test_format_fallback_rounding(self):
+        y = C.Decimal('6.09')
+        self.assertEqual(format(y, '#.1f'), '6.1')
+        with C.localcontext(rounding=C.ROUND_DOWN):
+            self.assertEqual(format(y, '#.1f'), '6.0')
+
+
 @requires_docstrings
 @requires_cdecimal
 class SignatureTest(unittest.TestCase):
diff --git a/Misc/NEWS.d/next/Library/2024-02-11-20-23-36.gh-issue-114563.RzxNYT.rst b/Misc/NEWS.d/next/Library/2024-02-11-20-23-36.gh-issue-114563.RzxNYT.rst
new file mode 100644 (file)
index 0000000..013b6db
--- /dev/null
@@ -0,0 +1,4 @@
+Fix several :func:`format()` bugs when using the C implementation of :class:`~decimal.Decimal`:
+* memory leak in some rare cases when using the ``z`` format option (coerce negative 0)
+* incorrect output when applying the ``z`` format option to type ``F`` (fixed-point with capital ``NAN`` / ``INF``)
+* incorrect output when applying the ``#`` format option (alternate form)
index 70b13982bb0cc4c09809b2f7fbcefec2f26e4f3e..1a195816fe526940e9bc80d9a9e77a1ebe572c84 100644 (file)
@@ -143,6 +143,8 @@ static PyObject *default_context_template = NULL;
 static PyObject *basic_context_template = NULL;
 static PyObject *extended_context_template = NULL;
 
+/* Invariant: NULL or pointer to _pydecimal.Decimal */
+static PyObject *PyDecimal = NULL;
 
 /* Error codes for functions that return signals or conditions */
 #define DEC_INVALID_SIGNALS (MPD_Max_status+1U)
@@ -3219,56 +3221,6 @@ dotsep_as_utf8(const char *s)
     return utf8;
 }
 
-/* copy of libmpdec _mpd_round() */
-static void
-_mpd_round(mpd_t *result, const mpd_t *a, mpd_ssize_t prec,
-           const mpd_context_t *ctx, uint32_t *status)
-{
-    mpd_ssize_t exp = a->exp + a->digits - prec;
-
-    if (prec <= 0) {
-        mpd_seterror(result, MPD_Invalid_operation, status);
-        return;
-    }
-    if (mpd_isspecial(a) || mpd_iszero(a)) {
-        mpd_qcopy(result, a, status);
-        return;
-    }
-
-    mpd_qrescale_fmt(result, a, exp, ctx, status);
-    if (result->digits > prec) {
-        mpd_qrescale_fmt(result, result, exp+1, ctx, status);
-    }
-}
-
-/* Locate negative zero "z" option within a UTF-8 format spec string.
- * Returns pointer to "z", else NULL.
- * The portion of the spec we're working with is [[fill]align][sign][z] */
-static const char *
-format_spec_z_search(char const *fmt, Py_ssize_t size) {
-    char const *pos = fmt;
-    char const *fmt_end = fmt + size;
-    /* skip over [[fill]align] (fill may be multi-byte character) */
-    pos += 1;
-    while (pos < fmt_end && *pos & 0x80) {
-        pos += 1;
-    }
-    if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
-        pos += 1;
-    } else {
-        /* fill not present-- skip over [align] */
-        pos = fmt;
-        if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
-            pos += 1;
-        }
-    }
-    /* skip over [sign] */
-    if (pos < fmt_end && strchr("+- ", *pos) != NULL) {
-        pos += 1;
-    }
-    return pos < fmt_end && *pos == 'z' ? pos : NULL;
-}
-
 static int
 dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const char **valuestr)
 {
@@ -3294,6 +3246,48 @@ dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const
     return 0;
 }
 
+/*
+ * Fallback _pydecimal formatting for new format specifiers that mpdecimal does
+ * not yet support. As documented, libmpdec follows the PEP-3101 format language:
+ * https://www.bytereef.org/mpdecimal/doc/libmpdec/assign-convert.html#to-string
+ */
+static PyObject *
+pydec_format(PyObject *dec, PyObject *context, PyObject *fmt)
+{
+    PyObject *result;
+    PyObject *pydec;
+    PyObject *u;
+
+    if (PyDecimal == NULL) {
+        PyDecimal = _PyImport_GetModuleAttrString("_pydecimal", "Decimal");
+        if (PyDecimal == NULL) {
+            return NULL;
+        }
+    }
+
+    u = dec_str(dec);
+    if (u == NULL) {
+        return NULL;
+    }
+
+    pydec = PyObject_CallOneArg(PyDecimal, u);
+    Py_DECREF(u);
+    if (pydec == NULL) {
+        return NULL;
+    }
+
+    result = PyObject_CallMethod(pydec, "__format__", "(OO)", fmt, context);
+    Py_DECREF(pydec);
+
+    if (result == NULL && PyErr_ExceptionMatches(PyExc_ValueError)) {
+        /* Do not confuse users with the _pydecimal exception */
+        PyErr_Clear();
+        PyErr_SetString(PyExc_ValueError, "invalid format string");
+    }
+
+    return result;
+}
+
 /* Formatted representation of a PyDecObject. */
 static PyObject *
 dec_format(PyObject *dec, PyObject *args)
@@ -3306,16 +3300,11 @@ dec_format(PyObject *dec, PyObject *args)
     PyObject *fmtarg;
     PyObject *context;
     mpd_spec_t spec;
-    char const *fmt;
-    char *fmt_copy = NULL;
+    char *fmt;
     char *decstring = NULL;
     uint32_t status = 0;
     int replace_fillchar = 0;
-    int no_neg_0 = 0;
     Py_ssize_t size;
-    mpd_t *mpd = MPD(dec);
-    mpd_uint_t dt[MPD_MINALLOC_MAX];
-    mpd_t tmp = {MPD_STATIC|MPD_STATIC_DATA,0,0,0,MPD_MINALLOC_MAX,dt};
 
 
     CURRENT_CONTEXT(context);
@@ -3324,39 +3313,20 @@ dec_format(PyObject *dec, PyObject *args)
     }
 
     if (PyUnicode_Check(fmtarg)) {
-        fmt = PyUnicode_AsUTF8AndSize(fmtarg, &size);
+        fmt = (char *)PyUnicode_AsUTF8AndSize(fmtarg, &size);
         if (fmt == NULL) {
             return NULL;
         }
-        /* NOTE: If https://github.com/python/cpython/pull/29438 lands, the
-         *   format string manipulation below can be eliminated by enhancing
-         *   the forked mpd_parse_fmt_str(). */
+
         if (size > 0 && fmt[0] == '\0') {
             /* NUL fill character: must be replaced with a valid UTF-8 char
                before calling mpd_parse_fmt_str(). */
             replace_fillchar = 1;
-            fmt = fmt_copy = dec_strdup(fmt, size);
-            if (fmt_copy == NULL) {
+            fmt = dec_strdup(fmt, size);
+            if (fmt == NULL) {
                 return NULL;
             }
-            fmt_copy[0] = '_';
-        }
-        /* Strip 'z' option, which isn't understood by mpd_parse_fmt_str().
-         * NOTE: fmt is always null terminated by PyUnicode_AsUTF8AndSize() */
-        char const *z_position = format_spec_z_search(fmt, size);
-        if (z_position != NULL) {
-            no_neg_0 = 1;
-            size_t z_index = z_position - fmt;
-            if (fmt_copy == NULL) {
-                fmt = fmt_copy = dec_strdup(fmt, size);
-                if (fmt_copy == NULL) {
-                    return NULL;
-                }
-            }
-            /* Shift characters (including null terminator) left,
-               overwriting the 'z' option. */
-            memmove(fmt_copy + z_index, fmt_copy + z_index + 1, size - z_index);
-            size -= 1;
+            fmt[0] = '_';
         }
     }
     else {
@@ -3366,10 +3336,13 @@ dec_format(PyObject *dec, PyObject *args)
     }
 
     if (!mpd_parse_fmt_str(&spec, fmt, CtxCaps(context))) {
-        PyErr_SetString(PyExc_ValueError,
-            "invalid format string");
-        goto finish;
+        if (replace_fillchar) {
+            PyMem_Free(fmt);
+        }
+
+        return pydec_format(dec, context, fmtarg);
     }
+
     if (replace_fillchar) {
         /* In order to avoid clobbering parts of UTF-8 thousands separators or
            decimal points when the substitution is reversed later, the actual
@@ -3422,45 +3395,8 @@ dec_format(PyObject *dec, PyObject *args)
         }
     }
 
-    if (no_neg_0 && mpd_isnegative(mpd) && !mpd_isspecial(mpd)) {
-        /* Round into a temporary (carefully mirroring the rounding
-           of mpd_qformat_spec()), and check if the result is negative zero.
-           If so, clear the sign and format the resulting positive zero. */
-        mpd_ssize_t prec;
-        mpd_qcopy(&tmp, mpd, &status);
-        if (spec.prec >= 0) {
-            switch (spec.type) {
-              case 'f':
-                  mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
-                  break;
-              case '%':
-                  tmp.exp += 2;
-                  mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
-                  break;
-              case 'g':
-                  prec = (spec.prec == 0) ? 1 : spec.prec;
-                  if (tmp.digits > prec) {
-                      _mpd_round(&tmp, &tmp, prec, CTX(context), &status);
-                  }
-                  break;
-              case 'e':
-                  if (!mpd_iszero(&tmp)) {
-                      _mpd_round(&tmp, &tmp, spec.prec+1, CTX(context), &status);
-                  }
-                  break;
-            }
-        }
-        if (status & MPD_Errors) {
-            PyErr_SetString(PyExc_ValueError, "unexpected error when rounding");
-            goto finish;
-        }
-        if (mpd_iszero(&tmp)) {
-            mpd_set_positive(&tmp);
-            mpd = &tmp;
-        }
-    }
 
-    decstring = mpd_qformat_spec(mpd, &spec, CTX(context), &status);
+    decstring = mpd_qformat_spec(MPD(dec), &spec, CTX(context), &status);
     if (decstring == NULL) {
         if (status & MPD_Malloc_error) {
             PyErr_NoMemory();
@@ -3483,7 +3419,7 @@ finish:
     Py_XDECREF(grouping);
     Py_XDECREF(sep);
     Py_XDECREF(dot);
-    if (fmt_copy) PyMem_Free(fmt_copy);
+    if (replace_fillchar) PyMem_Free(fmt);
     if (decstring) mpd_free(decstring);
     return result;
 }
@@ -5893,6 +5829,9 @@ PyInit__decimal(void)
     /* Create the module */
     ASSIGN_PTR(m, PyModule_Create(&_decimal_module));
 
+    /* For format specifiers not yet supported by libmpdec */
+    PyDecimal = NULL;
+
     /* Add types to the module */
     CHECK_INT(PyModule_AddObjectRef(m, "Decimal", (PyObject *)&PyDec_Type));
     CHECK_INT(PyModule_AddObjectRef(m, "Context", (PyObject *)&PyDecContext_Type));
index 8bfafd6e72606304928d1057d47b00272b2633a9..b47393e6fdd36a6821c868ba716b73c1ff75408c 100644 (file)
@@ -436,6 +436,7 @@ Modules/_decimal/_decimal.c -       basic_context_template  -
 Modules/_decimal/_decimal.c    -       current_context_var     -
 Modules/_decimal/_decimal.c    -       default_context_template        -
 Modules/_decimal/_decimal.c    -       extended_context_template       -
+Modules/_decimal/_decimal.c    -       PyDecimal       -
 Modules/_decimal/_decimal.c    -       round_map       -
 Modules/_decimal/_decimal.c    -       Rational        -
 Modules/_decimal/_decimal.c    -       SignalTuple     -