]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-130230: Fix crash in pow() with only Decimal third argument (GH-130237)
authorSerhiy Storchaka <storchaka@gmail.com>
Tue, 18 Feb 2025 07:46:48 +0000 (09:46 +0200)
committerGitHub <noreply@github.com>
Tue, 18 Feb 2025 07:46:48 +0000 (09:46 +0200)
Lib/test/test_decimal.py
Misc/NEWS.d/next/Library/2025-02-17-21-16-51.gh-issue-130230.9ta9P9.rst [new file with mode: 0644]
Modules/_decimal/_decimal.c

index 02d3fa985e75b9f9ec7ca9de44e36ee9dba76698..0e15eb3693888dd8fd07ae02e1c108398ba6cc27 100644 (file)
@@ -4481,6 +4481,15 @@ class Coverage:
             self.assertIs(Decimal("NaN").fma(7, 1).is_nan(), True)
             # three arg power
             self.assertEqual(pow(Decimal(10), 2, 7), 2)
+            if self.decimal == C:
+                self.assertEqual(pow(10, Decimal(2), 7), 2)
+                self.assertEqual(pow(10, 2, Decimal(7)), 2)
+            else:
+                # XXX: Three-arg power doesn't use __rpow__.
+                self.assertRaises(TypeError, pow, 10, Decimal(2), 7)
+                # XXX: There is no special method to dispatch on the
+                # third arg of three-arg power.
+                self.assertRaises(TypeError, pow, 10, 2, Decimal(7))
             # exp
             self.assertEqual(Decimal("1.01").exp(), 3)
             # is_normal
diff --git a/Misc/NEWS.d/next/Library/2025-02-17-21-16-51.gh-issue-130230.9ta9P9.rst b/Misc/NEWS.d/next/Library/2025-02-17-21-16-51.gh-issue-130230.9ta9P9.rst
new file mode 100644 (file)
index 0000000..20327fd
--- /dev/null
@@ -0,0 +1 @@
+Fix crash in :func:`pow` with only :class:`~decimal.Decimal` third argument.
index 3dcb3e9870c8a44ac6db0671311c7c192163d4b3..8a24d8c12cab8adfdc26dd56fd2cae95025d401f 100644 (file)
@@ -147,6 +147,24 @@ find_state_left_or_right(PyObject *left, PyObject *right)
     return (decimal_state *)state;
 }
 
+static inline decimal_state *
+find_state_ternary(PyObject *left, PyObject *right, PyObject *modulus)
+{
+    PyTypeObject *base;
+    if (PyType_GetBaseByToken(Py_TYPE(left), &dec_spec, &base) != 1) {
+        assert(!PyErr_Occurred());
+        if (PyType_GetBaseByToken(Py_TYPE(right), &dec_spec, &base) != 1) {
+            assert(!PyErr_Occurred());
+            PyType_GetBaseByToken(Py_TYPE(modulus), &dec_spec, &base);
+        }
+    }
+    assert(base != NULL);
+    void *state = _PyType_GetModuleState(base);
+    assert(state != NULL);
+    Py_DECREF(base);
+    return (decimal_state *)state;
+}
+
 
 #if !defined(MPD_VERSION_HEX) || MPD_VERSION_HEX < 0x02050000
   #error "libmpdec version >= 2.5.0 required"
@@ -4407,7 +4425,7 @@ nm_mpd_qpow(PyObject *base, PyObject *exp, PyObject *mod)
     PyObject *context;
     uint32_t status = 0;
 
-    decimal_state *state = find_state_left_or_right(base, exp);
+    decimal_state *state = find_state_ternary(base, exp, mod);
     CURRENT_CONTEXT(state, context);
     CONVERT_BINOP(&a, &b, base, exp, context);