git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-127936, PEP 757: Convert marshal module to use import/export API for ints (#128530)
author: Sergey B Kirpichev <skirpichev@gmail.com>
Thu, 23 Jan 2025 02:54:23 +0000 (05:54 +0300)
committer: GitHub <noreply@github.com>
Thu, 23 Jan 2025 02:54:23 +0000 (02:54 +0000)
Co-authored-by: Victor Stinner <vstinner@python.org>
Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com>
Python/marshal.c

index 72afa4ff89432c418e18dc8c9566f9b15febb529..cf7011652513ae89835986a5d3b6632b8fcfbf4f 100644 (file)
@@ -240,10 +240,6 @@ w_short_pstring(const void *s, Py_ssize_t n, WFILE *p)
 #define PyLong_MARSHAL_SHIFT 15
 #define PyLong_MARSHAL_BASE ((short)1 << PyLong_MARSHAL_SHIFT)
 #define PyLong_MARSHAL_MASK (PyLong_MARSHAL_BASE - 1)
-#if PyLong_SHIFT % PyLong_MARSHAL_SHIFT != 0
-#error "PyLong_SHIFT must be a multiple of PyLong_MARSHAL_SHIFT"
-#endif
-#define PyLong_MARSHAL_RATIO (PyLong_SHIFT / PyLong_MARSHAL_SHIFT)
 
 #define W_TYPE(t, p) do { \
     w_byte((t) | flag, (p)); \
@@ -252,47 +248,106 @@ w_short_pstring(const void *s, Py_ssize_t n, WFILE *p)
 static PyObject *
 _PyMarshal_WriteObjectToString(PyObject *x, int version, int allow_code);
 
+#define _r_digits(bitsize)                                                \
+static void                                                               \
+_r_digits##bitsize(const uint ## bitsize ## _t *digits, Py_ssize_t n,     \
+                   uint8_t negative, Py_ssize_t marshal_ratio, WFILE *p)  \
+{                                                                         \
+    /* set l to number of base PyLong_MARSHAL_BASE digits */              \
+    Py_ssize_t l = (n - 1)*marshal_ratio;                                 \
+    uint ## bitsize ## _t d = digits[n - 1];                              \
+                                                                          \
+    assert(marshal_ratio > 0);                                            \
+    assert(n >= 1);                                                       \
+    assert(d != 0); /* a PyLong is always normalized */                   \
+    do {                                                                  \
+        d >>= PyLong_MARSHAL_SHIFT;                                       \
+        l++;                                                              \
+    } while (d != 0);                                                     \
+    if (l > SIZE32_MAX) {                                                 \
+        p->depth--;                                                       \
+        p->error = WFERR_UNMARSHALLABLE;                                  \
+        return;                                                           \
+    }                                                                     \
+    w_long((long)(negative ? -l : l), p);                                 \
+                                                                          \
+    for (Py_ssize_t i = 0; i < n - 1; i++) {                              \
+        d = digits[i];                                                    \
+        for (Py_ssize_t j = 0; j < marshal_ratio; j++) {                  \
+            w_short(d & PyLong_MARSHAL_MASK, p);                          \
+            d >>= PyLong_MARSHAL_SHIFT;                                   \
+        }                                                                 \
+        assert(d == 0);                                                   \
+    }                                                                     \
+    d = digits[n - 1];                                                    \
+    do {                                                                  \
+        w_short(d & PyLong_MARSHAL_MASK, p);                              \
+        d >>= PyLong_MARSHAL_SHIFT;                                       \
+    } while (d != 0);                                                     \
+}
+_r_digits(16)
+_r_digits(32)
+#undef _r_digits
+
 static void
 w_PyLong(const PyLongObject *ob, char flag, WFILE *p)
 {
-    Py_ssize_t i, j, n, l;
-    digit d;
-
     W_TYPE(TYPE_LONG, p);
     if (_PyLong_IsZero(ob)) {
         w_long((long)0, p);
         return;
     }
 
-    /* set l to number of base PyLong_MARSHAL_BASE digits */
-    n = _PyLong_DigitCount(ob);
-    l = (n-1) * PyLong_MARSHAL_RATIO;
-    d = ob->long_value.ob_digit[n-1];
-    assert(d != 0); /* a PyLong is always normalized */
-    do {
-        d >>= PyLong_MARSHAL_SHIFT;
-        l++;
-    } while (d != 0);
-    if (l > SIZE32_MAX) {
+    PyLongExport long_export;
+
+    if (PyLong_Export((PyObject *)ob, &long_export) < 0) {
         p->depth--;
         p->error = WFERR_UNMARSHALLABLE;
         return;
     }
-    w_long((long)(_PyLong_IsNegative(ob) ? -l : l), p);
+    if (!long_export.digits) {
+        int8_t sign = long_export.value < 0 ? -1 : 1;
+        uint64_t abs_value = Py_ABS(long_export.value);
+        uint64_t d = abs_value;
+        long l = 0;
 
-    for (i=0; i < n-1; i++) {
-        d = ob->long_value.ob_digit[i];
-        for (j=0; j < PyLong_MARSHAL_RATIO; j++) {
+        /* set l to number of base PyLong_MARSHAL_BASE digits */
+        do {
+            d >>= PyLong_MARSHAL_SHIFT;
+            l += sign;
+        } while (d);
+        w_long(l, p);
+
+        d = abs_value;
+        do {
             w_short(d & PyLong_MARSHAL_MASK, p);
             d >>= PyLong_MARSHAL_SHIFT;
-        }
-        assert (d == 0);
+        } while (d);
+        return;
     }
-    d = ob->long_value.ob_digit[n-1];
-    do {
-        w_short(d & PyLong_MARSHAL_MASK, p);
-        d >>= PyLong_MARSHAL_SHIFT;
-    } while (d != 0);
+
+    const PyLongLayout *layout = PyLong_GetNativeLayout();
+    Py_ssize_t marshal_ratio = layout->bits_per_digit/PyLong_MARSHAL_SHIFT;
+
+    /* must be a multiple of PyLong_MARSHAL_SHIFT */
+    assert(layout->bits_per_digit % PyLong_MARSHAL_SHIFT == 0);
+    assert(layout->bits_per_digit >= PyLong_MARSHAL_SHIFT);
+
+    /* other assumptions on PyLongObject internals */
+    assert(layout->bits_per_digit <= 32);
+    assert(layout->digits_order == -1);
+    assert(layout->digit_endianness == (PY_LITTLE_ENDIAN ? -1 : 1));
+    assert(layout->digit_size == 2 || layout->digit_size == 4);
+
+    if (layout->digit_size == 4) {
+        _r_digits32(long_export.digits, long_export.ndigits,
+                    long_export.negative, marshal_ratio, p);
+    }
+    else {
+        _r_digits16(long_export.digits, long_export.ndigits,
+                    long_export.negative, marshal_ratio, p);
+    }
+    PyLong_FreeExport(&long_export);
 }
 
 static void
@@ -875,17 +930,62 @@ r_long64(RFILE *p)
                                  1 /* signed */);
 }
 
+#define _w_digits(bitsize)                                              \
+static int                                                              \
+_w_digits##bitsize(uint ## bitsize ## _t *digits, Py_ssize_t size,      \
+                   Py_ssize_t marshal_ratio,                            \
+                   int shorts_in_top_digit, RFILE *p)                   \
+{                                                                       \
+    uint ## bitsize ## _t d;                                            \
+                                                                        \
+    assert(size >= 1);                                                  \
+    for (Py_ssize_t i = 0; i < size - 1; i++) {                         \
+        d = 0;                                                          \
+        for (Py_ssize_t j = 0; j < marshal_ratio; j++) {                \
+            int md = r_short(p);                                        \
+            if (md < 0 || md > PyLong_MARSHAL_BASE) {                   \
+                goto bad_digit;                                         \
+            }                                                           \
+            d += (uint ## bitsize ## _t)md << j*PyLong_MARSHAL_SHIFT;   \
+        }                                                               \
+        digits[i] = d;                                                  \
+    }                                                                   \
+                                                                        \
+    d = 0;                                                              \
+    for (Py_ssize_t j = 0; j < shorts_in_top_digit; j++) {              \
+        int md = r_short(p);                                            \
+        if (md < 0 || md > PyLong_MARSHAL_BASE) {                       \
+            goto bad_digit;                                             \
+        }                                                               \
+        /* topmost marshal digit should be nonzero */                   \
+        if (md == 0 && j == shorts_in_top_digit - 1) {                  \
+            PyErr_SetString(PyExc_ValueError,                           \
+                "bad marshal data (unnormalized long data)");           \
+            return -1;                                                  \
+        }                                                               \
+        d += (uint ## bitsize ## _t)md << j*PyLong_MARSHAL_SHIFT;       \
+    }                                                                   \
+    assert(!PyErr_Occurred());                                          \
+    /* top digit should be nonzero, else the resulting PyLong won't be  \
+       normalized */                                                    \
+    digits[size - 1] = d;                                               \
+    return 0;                                                           \
+                                                                        \
+bad_digit:                                                              \
+    if (!PyErr_Occurred()) {                                            \
+        PyErr_SetString(PyExc_ValueError,                               \
+            "bad marshal data (digit out of range in long)");           \
+    }                                                                   \
+    return -1;                                                          \
+}
+_w_digits(32)
+_w_digits(16)
+#undef _w_digits
+
 static PyObject *
 r_PyLong(RFILE *p)
 {
-    PyLongObject *ob;
-    long n, size, i;
-    int j, md, shorts_in_top_digit;
-    digit d;
-
-    n = r_long(p);
-    if (n == 0)
-        return (PyObject *)_PyLong_New(0);
+    long n = r_long(p);
     if (n == -1 && PyErr_Occurred()) {
         return NULL;
     }
@@ -895,51 +995,44 @@ r_PyLong(RFILE *p)
         return NULL;
     }
 
-    size = 1 + (Py_ABS(n) - 1) / PyLong_MARSHAL_RATIO;
-    shorts_in_top_digit = 1 + (Py_ABS(n) - 1) % PyLong_MARSHAL_RATIO;
-    ob = _PyLong_New(size);
-    if (ob == NULL)
-        return NULL;
+    const PyLongLayout *layout = PyLong_GetNativeLayout();
+    Py_ssize_t marshal_ratio = layout->bits_per_digit/PyLong_MARSHAL_SHIFT;
 
-    _PyLong_SetSignAndDigitCount(ob, n < 0 ? -1 : 1, size);
+    /* must be a multiple of PyLong_MARSHAL_SHIFT */
+    assert(layout->bits_per_digit % PyLong_MARSHAL_SHIFT == 0);
+    assert(layout->bits_per_digit >= PyLong_MARSHAL_SHIFT);
 
-    for (i = 0; i < size-1; i++) {
-        d = 0;
-        for (j=0; j < PyLong_MARSHAL_RATIO; j++) {
-            md = r_short(p);
-            if (md < 0 || md > PyLong_MARSHAL_BASE)
-                goto bad_digit;
-            d += (digit)md << j*PyLong_MARSHAL_SHIFT;
-        }
-        ob->long_value.ob_digit[i] = d;
+    /* other assumptions on PyLongObject internals */
+    assert(layout->bits_per_digit <= 32);
+    assert(layout->digits_order == -1);
+    assert(layout->digit_endianness == (PY_LITTLE_ENDIAN ? -1 : 1));
+    assert(layout->digit_size == 2 || layout->digit_size == 4);
+
+    Py_ssize_t size = 1 + (Py_ABS(n) - 1) / marshal_ratio;
+
+    assert(size >= 1);
+
+    int shorts_in_top_digit = 1 + (Py_ABS(n) - 1) % marshal_ratio;
+    void *digits;
+    PyLongWriter *writer = PyLongWriter_Create(n < 0, size, &digits);
+
+    if (writer == NULL) {
+        return NULL;
     }
 
-    d = 0;
-    for (j=0; j < shorts_in_top_digit; j++) {
-        md = r_short(p);
-        if (md < 0 || md > PyLong_MARSHAL_BASE)
-            goto bad_digit;
-        /* topmost marshal digit should be nonzero */
-        if (md == 0 && j == shorts_in_top_digit - 1) {
-            Py_DECREF(ob);
-            PyErr_SetString(PyExc_ValueError,
-                "bad marshal data (unnormalized long data)");
-            return NULL;
-        }
-        d += (digit)md << j*PyLong_MARSHAL_SHIFT;
+    int ret;
+
+    if (layout->digit_size == 4) {
+        ret = _w_digits32(digits, size, marshal_ratio, shorts_in_top_digit, p);
     }
-    assert(!PyErr_Occurred());
-    /* top digit should be nonzero, else the resulting PyLong won't be
-       normalized */
-    ob->long_value.ob_digit[size-1] = d;
-    return (PyObject *)ob;
-  bad_digit:
-    Py_DECREF(ob);
-    if (!PyErr_Occurred()) {
-        PyErr_SetString(PyExc_ValueError,
-                        "bad marshal data (digit out of range in long)");
+    else {
+        ret = _w_digits16(digits, size, marshal_ratio, shorts_in_top_digit, p);
+    }
+    if (ret < 0) {
+        PyLongWriter_Discard(writer);
+        return NULL;
     }
-    return NULL;
+    return PyLongWriter_Finish(writer);
 }
 
 static double