]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-124688: _decimal: Get module state from ctx for performance (#124691)
authorneonene <53406459+neonene@users.noreply.github.com>
Sat, 28 Sep 2024 16:12:53 +0000 (01:12 +0900)
committerGitHub <noreply@github.com>
Sat, 28 Sep 2024 16:12:53 +0000 (16:12 +0000)
Get a module state from ctx objects for performance.

Modules/_decimal/_decimal.c

index 68d1da9faab86724bd5a1d0c19c667c70dec3922..a33c9793b5ad17d7e46588a65d5674feac4626ca 100644 (file)
@@ -123,6 +123,7 @@ get_module_state(PyObject *mod)
 
 static struct PyModuleDef _decimal_module;
 static PyType_Spec dec_spec;
+static PyType_Spec context_spec;
 
 static inline decimal_state *
 get_module_state_by_def(PyTypeObject *tp)
@@ -190,6 +191,7 @@ typedef struct PyDecContextObject {
     PyObject *flags;
     int capitals;
     PyThreadState *tstate;
+    decimal_state *modstate;
 } PyDecContextObject;
 
 typedef struct {
@@ -210,6 +212,15 @@ typedef struct {
 #define CTX(v) (&((PyDecContextObject *)v)->ctx)
 #define CtxCaps(v) (((PyDecContextObject *)v)->capitals)
 
+static inline decimal_state *
+get_module_state_from_ctx(PyObject *v)
+{
+    assert(PyType_GetBaseByToken(Py_TYPE(v), &context_spec, NULL) == 1);
+    decimal_state *state = ((PyDecContextObject *)v)->modstate;
+    assert(state != NULL);
+    return state;
+}
+
 
 Py_LOCAL_INLINE(PyObject *)
 incr_true(void)
@@ -564,7 +575,7 @@ static int
 dec_addstatus(PyObject *context, uint32_t status)
 {
     mpd_context_t *ctx = CTX(context);
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
 
     ctx->status |= status;
     if (status & (ctx->traps|MPD_Malloc_error)) {
@@ -859,7 +870,7 @@ static PyObject *
 context_getround(PyObject *self, void *Py_UNUSED(closure))
 {
     int i = mpd_getround(CTX(self));
-    decimal_state *state = get_module_state_by_def(Py_TYPE(self));
+    decimal_state *state = get_module_state_from_ctx(self);
 
     return Py_NewRef(state->round_map[i]);
 }
@@ -1018,7 +1029,7 @@ context_setround(PyObject *self, PyObject *value, void *Py_UNUSED(closure))
     mpd_context_t *ctx;
     int x;
 
-    decimal_state *state = get_module_state_by_def(Py_TYPE(self));
+    decimal_state *state = get_module_state_from_ctx(self);
     x = getround(state, value);
     if (x == -1) {
         return -1;
@@ -1077,7 +1088,7 @@ context_settraps_list(PyObject *self, PyObject *value)
 {
     mpd_context_t *ctx;
     uint32_t flags;
-    decimal_state *state = get_module_state_by_def(Py_TYPE(self));
+    decimal_state *state = get_module_state_from_ctx(self);
     flags = list_as_flags(state, value);
     if (flags & DEC_ERRORS) {
         return -1;
@@ -1097,7 +1108,7 @@ context_settraps_dict(PyObject *self, PyObject *value)
     mpd_context_t *ctx;
     uint32_t flags;
 
-    decimal_state *state = get_module_state_by_def(Py_TYPE(self));
+    decimal_state *state = get_module_state_from_ctx(self);
     if (PyDecSignalDict_Check(state, value)) {
         flags = SdFlags(value);
     }
@@ -1142,7 +1153,7 @@ context_setstatus_list(PyObject *self, PyObject *value)
 {
     mpd_context_t *ctx;
     uint32_t flags;
-    decimal_state *state = get_module_state_by_def(Py_TYPE(self));
+    decimal_state *state = get_module_state_from_ctx(self);
 
     flags = list_as_flags(state, value);
     if (flags & DEC_ERRORS) {
@@ -1163,7 +1174,7 @@ context_setstatus_dict(PyObject *self, PyObject *value)
     mpd_context_t *ctx;
     uint32_t flags;
 
-    decimal_state *state = get_module_state_by_def(Py_TYPE(self));
+    decimal_state *state = get_module_state_from_ctx(self);
     if (PyDecSignalDict_Check(state, value)) {
         flags = SdFlags(value);
     }
@@ -1393,6 +1404,7 @@ context_new(PyTypeObject *type,
 
     CtxCaps(self) = 1;
     self->tstate = NULL;
+    self->modstate = state;
 
     if (type == state->PyDecContext_Type) {
         PyObject_GC_Track(self);
@@ -1470,7 +1482,7 @@ context_repr(PyDecContextObject *self)
     int n, mem;
 
 #ifdef Py_DEBUG
-    decimal_state *state = get_module_state_by_def(Py_TYPE(self));
+    decimal_state *state = get_module_state_from_ctx((PyObject *)self);
     assert(PyDecContext_Check(state, self));
 #endif
     ctx = CTX(self);
@@ -1561,7 +1573,7 @@ context_copy(PyObject *self, PyObject *Py_UNUSED(dummy))
 {
     PyObject *copy;
 
-    decimal_state *state = get_module_state_by_def(Py_TYPE(self));
+    decimal_state *state = get_module_state_from_ctx(self);
     copy = PyObject_CallObject((PyObject *)state->PyDecContext_Type, NULL);
     if (copy == NULL) {
         return NULL;
@@ -1581,7 +1593,7 @@ context_reduce(PyObject *self, PyObject *Py_UNUSED(dummy))
     PyObject *traps;
     PyObject *ret;
     mpd_context_t *ctx;
-    decimal_state *state = get_module_state_by_def(Py_TYPE(self));
+    decimal_state *state = get_module_state_from_ctx(self);
 
     ctx = CTX(self);
 
@@ -2022,11 +2034,10 @@ static PyType_Spec ctxmanager_spec = {
 /******************************************************************************/
 
 static PyObject *
-PyDecType_New(PyTypeObject *type)
+PyDecType_New(decimal_state *state, PyTypeObject *type)
 {
     PyDecObject *dec;
 
-    decimal_state *state = get_module_state_by_def(type);
     if (type == state->PyDec_Type) {
         dec = PyObject_GC_New(PyDecObject, state->PyDec_Type);
     }
@@ -2052,7 +2063,7 @@ PyDecType_New(PyTypeObject *type)
     assert(PyObject_GC_IsTracked((PyObject *)dec));
     return (PyObject *)dec;
 }
-#define dec_alloc(st) PyDecType_New((st)->PyDec_Type)
+#define dec_alloc(st) PyDecType_New(st, (st)->PyDec_Type)
 
 static int
 dec_traverse(PyObject *dec, visitproc visit, void *arg)
@@ -2155,7 +2166,8 @@ PyDecType_FromCString(PyTypeObject *type, const char *s,
     PyObject *dec;
     uint32_t status = 0;
 
-    dec = PyDecType_New(type);
+    decimal_state *state = get_module_state_from_ctx(context);
+    dec = PyDecType_New(state, type);
     if (dec == NULL) {
         return NULL;
     }
@@ -2179,7 +2191,8 @@ PyDecType_FromCStringExact(PyTypeObject *type, const char *s,
     uint32_t status = 0;
     mpd_context_t maxctx;
 
-    dec = PyDecType_New(type);
+    decimal_state *state = get_module_state_from_ctx(context);
+    dec = PyDecType_New(state, type);
     if (dec == NULL) {
         return NULL;
     }
@@ -2266,7 +2279,8 @@ PyDecType_FromSsize(PyTypeObject *type, mpd_ssize_t v, PyObject *context)
     PyObject *dec;
     uint32_t status = 0;
 
-    dec = PyDecType_New(type);
+    decimal_state *state = get_module_state_from_ctx(context);
+    dec = PyDecType_New(state, type);
     if (dec == NULL) {
         return NULL;
     }
@@ -2287,7 +2301,8 @@ PyDecType_FromSsizeExact(PyTypeObject *type, mpd_ssize_t v, PyObject *context)
     uint32_t status = 0;
     mpd_context_t maxctx;
 
-    dec = PyDecType_New(type);
+    decimal_state *state = get_module_state_from_ctx(context);
+    dec = PyDecType_New(state, type);
     if (dec == NULL) {
         return NULL;
     }
@@ -2305,13 +2320,13 @@ PyDecType_FromSsizeExact(PyTypeObject *type, mpd_ssize_t v, PyObject *context)
 /* Convert from a PyLongObject. The context is not modified; flags set
    during conversion are accumulated in the status parameter. */
 static PyObject *
-dec_from_long(PyTypeObject *type, PyObject *v,
+dec_from_long(decimal_state *state, PyTypeObject *type, PyObject *v,
               const mpd_context_t *ctx, uint32_t *status)
 {
     PyObject *dec;
     PyLongObject *l = (PyLongObject *)v;
 
-    dec = PyDecType_New(type);
+    dec = PyDecType_New(state, type);
     if (dec == NULL) {
         return NULL;
     }
@@ -2356,7 +2371,8 @@ PyDecType_FromLong(PyTypeObject *type, PyObject *v, PyObject *context)
         return NULL;
     }
 
-    dec = dec_from_long(type, v, CTX(context), &status);
+    decimal_state *state = get_module_state_from_ctx(context);
+    dec = dec_from_long(state, type, v, CTX(context), &status);
     if (dec == NULL) {
         return NULL;
     }
@@ -2385,7 +2401,8 @@ PyDecType_FromLongExact(PyTypeObject *type, PyObject *v,
     }
 
     mpd_maxcontext(&maxctx);
-    dec = dec_from_long(type, v, &maxctx, &status);
+    decimal_state *state = get_module_state_from_ctx(context);
+    dec = dec_from_long(state, type, v, &maxctx, &status);
     if (dec == NULL) {
         return NULL;
     }
@@ -2417,7 +2434,7 @@ PyDecType_FromFloatExact(PyTypeObject *type, PyObject *v,
     mpd_t *d1, *d2;
     uint32_t status = 0;
     mpd_context_t maxctx;
-    decimal_state *state = get_module_state_by_def(type);
+    decimal_state *state = get_module_state_from_ctx(context);
 
 #ifdef Py_DEBUG
     assert(PyType_IsSubtype(type, state->PyDec_Type));
@@ -2438,7 +2455,7 @@ PyDecType_FromFloatExact(PyTypeObject *type, PyObject *v,
     sign = (copysign(1.0, x) == 1.0) ? 0 : 1;
 
     if (isnan(x) || isinf(x)) {
-        dec = PyDecType_New(type);
+        dec = PyDecType_New(state, type);
         if (dec == NULL) {
             return NULL;
         }
@@ -2555,12 +2572,12 @@ PyDecType_FromDecimalExact(PyTypeObject *type, PyObject *v, PyObject *context)
     PyObject *dec;
     uint32_t status = 0;
 
-    decimal_state *state = get_module_state_by_def(type);
+    decimal_state *state = get_module_state_from_ctx(context);
     if (type == state->PyDec_Type && PyDec_CheckExact(state, v)) {
         return Py_NewRef(v);
     }
 
-    dec = PyDecType_New(type);
+    dec = PyDecType_New(state, type);
     if (dec == NULL) {
         return NULL;
     }
@@ -2844,7 +2861,7 @@ dec_from_float(PyObject *type, PyObject *pyfloat)
 static PyObject *
 ctx_from_float(PyObject *context, PyObject *v)
 {
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     return PyDec_FromFloat(state, v, context);
 }
 
@@ -2855,7 +2872,7 @@ dec_apply(PyObject *v, PyObject *context)
     PyObject *result;
     uint32_t status = 0;
 
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     result = dec_alloc(state);
     if (result == NULL) {
         return NULL;
@@ -2882,7 +2899,7 @@ dec_apply(PyObject *v, PyObject *context)
 static PyObject *
 PyDecType_FromObjectExact(PyTypeObject *type, PyObject *v, PyObject *context)
 {
-    decimal_state *state = get_module_state_by_def(type);
+    decimal_state *state = get_module_state_from_ctx(context);
     if (v == NULL) {
         return PyDecType_FromSsizeExact(type, 0, context);
     }
@@ -2917,7 +2934,7 @@ PyDecType_FromObjectExact(PyTypeObject *type, PyObject *v, PyObject *context)
 static PyObject *
 PyDec_FromObject(PyObject *v, PyObject *context)
 {
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     if (v == NULL) {
         return PyDec_FromSsize(state, 0, context);
     }
@@ -3004,7 +3021,7 @@ ctx_create_decimal(PyObject *context, PyObject *args)
 Py_LOCAL_INLINE(int)
 convert_op(int type_err, PyObject **conv, PyObject *v, PyObject *context)
 {
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     if (PyDec_Check(state, v)) {
         *conv = Py_NewRef(v);
         return 1;
@@ -3107,7 +3124,7 @@ multiply_by_denominator(PyObject *v, PyObject *r, PyObject *context)
     if (tmp == NULL) {
         return NULL;
     }
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     denom = PyDec_FromLongExact(state, tmp, context);
     Py_DECREF(tmp);
     if (denom == NULL) {
@@ -3162,7 +3179,7 @@ numerator_as_decimal(PyObject *r, PyObject *context)
         return NULL;
     }
 
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     num = PyDec_FromLongExact(state, tmp, context);
     Py_DECREF(tmp);
     return num;
@@ -3181,7 +3198,7 @@ convert_op_cmp(PyObject **vcmp, PyObject **wcmp, PyObject *v, PyObject *w,
 
     *vcmp = v;
 
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     if (PyDec_Check(state, w)) {
         *wcmp = Py_NewRef(w);
     }
@@ -4421,12 +4438,11 @@ dec_conjugate(PyObject *self, PyObject *Py_UNUSED(dummy))
     return Py_NewRef(self);
 }
 
-static PyObject *
-dec_mpd_radix(PyObject *self, PyObject *Py_UNUSED(dummy))
+static inline PyObject *
+_dec_mpd_radix(decimal_state *state)
 {
     PyObject *result;
 
-    decimal_state *state = get_module_state_by_def(Py_TYPE(self));
     result = dec_alloc(state);
     if (result == NULL) {
         return NULL;
@@ -4436,6 +4452,13 @@ dec_mpd_radix(PyObject *self, PyObject *Py_UNUSED(dummy))
     return result;
 }
 
+static PyObject *
+dec_mpd_radix(PyObject *self, PyObject *Py_UNUSED(dummy))
+{
+    decimal_state *state = get_module_state_by_def(Py_TYPE(self));
+    return _dec_mpd_radix(state);
+}
+
 static PyObject *
 dec_mpd_qcopy_abs(PyObject *self, PyObject *Py_UNUSED(dummy))
 {
@@ -5138,7 +5161,7 @@ ctx_##MPDFUNC(PyObject *context, PyObject *v)            \
                                                          \
     CONVERT_OP_RAISE(&a, v, context);                    \
     decimal_state *state =                               \
-        get_module_state_by_def(Py_TYPE(context));       \
+        get_module_state_from_ctx(context);              \
     if ((result = dec_alloc(state)) == NULL) {           \
         Py_DECREF(a);                                    \
         return NULL;                                     \
@@ -5170,7 +5193,7 @@ ctx_##MPDFUNC(PyObject *context, PyObject *args)                 \
                                                                  \
     CONVERT_BINOP_RAISE(&a, &b, v, w, context);                  \
     decimal_state *state =                                       \
-        get_module_state_by_def(Py_TYPE(context));               \
+        get_module_state_from_ctx(context);                      \
     if ((result = dec_alloc(state)) == NULL) {                   \
         Py_DECREF(a);                                            \
         Py_DECREF(b);                                            \
@@ -5206,7 +5229,7 @@ ctx_##MPDFUNC(PyObject *context, PyObject *args) \
                                                  \
     CONVERT_BINOP_RAISE(&a, &b, v, w, context);  \
     decimal_state *state =                       \
-        get_module_state_by_def(Py_TYPE(context)); \
+        get_module_state_from_ctx(context);      \
     if ((result = dec_alloc(state)) == NULL) {   \
         Py_DECREF(a);                            \
         Py_DECREF(b);                            \
@@ -5235,7 +5258,7 @@ ctx_##MPDFUNC(PyObject *context, PyObject *args)                         \
     }                                                                    \
                                                                          \
     CONVERT_TERNOP_RAISE(&a, &b, &c, v, w, x, context);                  \
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));    \
+    decimal_state *state = get_module_state_from_ctx(context);           \
     if ((result = dec_alloc(state)) == NULL) {                           \
         Py_DECREF(a);                                                    \
         Py_DECREF(b);                                                    \
@@ -5301,7 +5324,7 @@ ctx_mpd_qdivmod(PyObject *context, PyObject *args)
     }
 
     CONVERT_BINOP_RAISE(&a, &b, v, w, context);
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     q = dec_alloc(state);
     if (q == NULL) {
         Py_DECREF(a);
@@ -5356,7 +5379,7 @@ ctx_mpd_qpow(PyObject *context, PyObject *args, PyObject *kwds)
         }
     }
 
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     result = dec_alloc(state);
     if (result == NULL) {
         Py_DECREF(a);
@@ -5391,7 +5414,8 @@ DecCtx_TernaryFunc(mpd_qfma)
 static PyObject *
 ctx_mpd_radix(PyObject *context, PyObject *dummy)
 {
-    return dec_mpd_radix(context, dummy);
+    decimal_state *state = get_module_state_from_ctx(context);
+    return _dec_mpd_radix(state);
 }
 
 /* Boolean functions: single decimal argument */
@@ -5408,7 +5432,7 @@ DecCtx_BoolFunc_NO_CTX(mpd_iszero)
 static PyObject *
 ctx_iscanonical(PyObject *context, PyObject *v)
 {
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     if (!PyDec_Check(state, v)) {
         PyErr_SetString(PyExc_TypeError,
             "argument must be a Decimal");
@@ -5434,7 +5458,7 @@ PyDecContext_Apply(PyObject *context, PyObject *v)
 static PyObject *
 ctx_canonical(PyObject *context, PyObject *v)
 {
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     if (!PyDec_Check(state, v)) {
         PyErr_SetString(PyExc_TypeError,
             "argument must be a Decimal");
@@ -5451,7 +5475,7 @@ ctx_mpd_qcopy_abs(PyObject *context, PyObject *v)
     uint32_t status = 0;
 
     CONVERT_OP_RAISE(&a, v, context);
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     result = dec_alloc(state);
     if (result == NULL) {
         Py_DECREF(a);
@@ -5484,7 +5508,7 @@ ctx_mpd_qcopy_negate(PyObject *context, PyObject *v)
     uint32_t status = 0;
 
     CONVERT_OP_RAISE(&a, v, context);
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     result = dec_alloc(state);
     if (result == NULL) {
         Py_DECREF(a);
@@ -5581,7 +5605,7 @@ ctx_mpd_qcopy_sign(PyObject *context, PyObject *args)
     }
 
     CONVERT_BINOP_RAISE(&a, &b, v, w, context);
-    decimal_state *state = get_module_state_by_def(Py_TYPE(context));
+    decimal_state *state = get_module_state_from_ctx(context);
     result = dec_alloc(state);
     if (result == NULL) {
         Py_DECREF(a);
@@ -5736,6 +5760,7 @@ static PyMethodDef context_methods [] =
 };
 
 static PyType_Slot context_slots[] = {
+    {Py_tp_token, Py_TP_USE_SPEC},
     {Py_tp_dealloc, context_dealloc},
     {Py_tp_traverse, context_traverse},
     {Py_tp_clear, context_clear},