]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-116621: Set manual critical section for list.extend (gh-116657)
authorDonghee Na <donghee.na@python.org>
Tue, 12 Mar 2024 22:28:23 +0000 (07:28 +0900)
committerGitHub <noreply@github.com>
Tue, 12 Mar 2024 22:28:23 +0000 (07:28 +0900)
Objects/clinic/listobject.c.h
Objects/listobject.c

index b90dc0a0b463c638aa78b80944a1b2c5af97daae..588e021fb71fd39e4de9ed07236a5b97d287a613 100644 (file)
@@ -125,29 +125,14 @@ list_append(PyListObject *self, PyObject *object)
     return return_value;
 }
 
-PyDoc_STRVAR(py_list_extend__doc__,
+PyDoc_STRVAR(list_extend__doc__,
 "extend($self, iterable, /)\n"
 "--\n"
 "\n"
 "Extend list by appending elements from the iterable.");
 
-#define PY_LIST_EXTEND_METHODDEF    \
-    {"extend", (PyCFunction)py_list_extend, METH_O, py_list_extend__doc__},
-
-static PyObject *
-py_list_extend_impl(PyListObject *self, PyObject *iterable);
-
-static PyObject *
-py_list_extend(PyListObject *self, PyObject *iterable)
-{
-    PyObject *return_value = NULL;
-
-    Py_BEGIN_CRITICAL_SECTION2(self, iterable);
-    return_value = py_list_extend_impl(self, iterable);
-    Py_END_CRITICAL_SECTION2();
-
-    return return_value;
-}
+#define LIST_EXTEND_METHODDEF    \
+    {"extend", (PyCFunction)list_extend, METH_O, list_extend__doc__},
 
 PyDoc_STRVAR(list_pop__doc__,
 "pop($self, index=-1, /)\n"
@@ -454,4 +439,4 @@ list___reversed__(PyListObject *self, PyObject *Py_UNUSED(ignored))
 {
     return list___reversed___impl(self);
 }
-/*[clinic end generated code: output=a77eda9931ec0c20 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=854957a1d4a89bbd input=a9049054013a1b77]*/
index 164f363efe24f05440b867ae20c835d7023475d1..759902c06b4ef3603b7dc70865f3429095c40557 100644 (file)
@@ -10,6 +10,7 @@
 #include "pycore_modsupport.h"    // _PyArg_NoKwnames()
 #include "pycore_object.h"        // _PyObject_GC_TRACK(), _PyDebugAllocatorStats()
 #include "pycore_tuple.h"         // _PyTuple_FromArray()
+#include "pycore_setobject.h"     // _PySet_NextEntry()
 #include <stddef.h>
 
 /*[clinic input]
@@ -994,26 +995,28 @@ PyList_SetSlice(PyObject *a, Py_ssize_t ilow, Py_ssize_t ihigh, PyObject *v)
     return list_ass_slice((PyListObject *)a, ilow, ihigh, v);
 }
 
-static PyObject *
+static int
 list_inplace_repeat_lock_held(PyListObject *self, Py_ssize_t n)
 {
     Py_ssize_t input_size = PyList_GET_SIZE(self);
     if (input_size == 0 || n == 1) {
-        return Py_NewRef(self);
+        return 0;
     }
 
     if (n < 1) {
         list_clear(self);
-        return Py_NewRef(self);
+        return 0;
     }
 
     if (input_size > PY_SSIZE_T_MAX / n) {
-        return PyErr_NoMemory();
+        PyErr_NoMemory();
+        return -1;
     }
     Py_ssize_t output_size = input_size * n;
 
-    if (list_resize(self, output_size) < 0)
-        return NULL;
+    if (list_resize(self, output_size) < 0) {
+        return -1;
+    }
 
     PyObject **items = self->ob_item;
     for (Py_ssize_t j = 0; j < input_size; j++) {
@@ -1021,8 +1024,7 @@ list_inplace_repeat_lock_held(PyListObject *self, Py_ssize_t n)
     }
     _Py_memory_repeat((char *)items, sizeof(PyObject *)*output_size,
                       sizeof(PyObject *)*input_size);
-
-    return Py_NewRef(self);
+    return 0;
 }
 
 static PyObject *
@@ -1031,7 +1033,12 @@ list_inplace_repeat(PyObject *_self, Py_ssize_t n)
     PyObject *ret;
     PyListObject *self = (PyListObject *) _self;
     Py_BEGIN_CRITICAL_SECTION(self);
-    ret = list_inplace_repeat_lock_held(self, n);
+    if (list_inplace_repeat_lock_held(self, n) < 0) {
+        ret = NULL;
+    }
+    else {
+        ret = Py_NewRef(self);
+    }
     Py_END_CRITICAL_SECTION();
     return ret;
 }
@@ -1179,7 +1186,7 @@ list_extend_fast(PyListObject *self, PyObject *iterable)
 }
 
 static int
-list_extend_iter(PyListObject *self, PyObject *iterable)
+list_extend_iter_lock_held(PyListObject *self, PyObject *iterable)
 {
     PyObject *it = PyObject_GetIter(iterable);
     if (it == NULL) {
@@ -1253,45 +1260,78 @@ list_extend_iter(PyListObject *self, PyObject *iterable)
     return -1;
 }
 
-
 static int
-list_extend(PyListObject *self, PyObject *iterable)
+list_extend_lock_held(PyListObject *self, PyObject *iterable)
 {
-    // Special cases:
-    // 1) lists and tuples which can use PySequence_Fast ops
-    // 2) extending self to self requires making a copy first
-    if (PyList_CheckExact(iterable)
-        || PyTuple_CheckExact(iterable)
-        || (PyObject *)self == iterable)
-    {
-        iterable = PySequence_Fast(iterable, "argument must be iterable");
-        if (!iterable) {
-            return -1;
-        }
-
-        int res = list_extend_fast(self, iterable);
-        Py_DECREF(iterable);
-        return res;
-    }
-    else {
-        return list_extend_iter(self, iterable);
+    PyObject *seq = PySequence_Fast(iterable, "argument must be iterable");
+    if (!seq) {
+        return -1;
     }
-}
 
+    int res = list_extend_fast(self, seq);
+    Py_DECREF(seq);
+    return res;
+}
 
-PyObject *
-_PyList_Extend(PyListObject *self, PyObject *iterable)
+static int
+list_extend_set(PyListObject *self, PySetObject *other)
 {
-    if (list_extend(self, iterable) < 0) {
-        return NULL;
+    Py_ssize_t m = Py_SIZE(self);
+    Py_ssize_t n = PySet_GET_SIZE(other);
+    if (list_resize(self, m + n) < 0) {
+        return -1;
     }
-    Py_RETURN_NONE;
+    /* populate the end of self with iterable's items */
+    Py_ssize_t setpos = 0;
+    Py_hash_t hash;
+    PyObject *key;
+    PyObject **dest = self->ob_item + m;
+    while (_PySet_NextEntry((PyObject *)other, &setpos, &key, &hash)) {
+        Py_INCREF(key);
+        *dest = key;
+        dest++;
+    }
+    Py_SET_SIZE(self, m + n);
+    return 0;
 }
 
+static int
+_list_extend(PyListObject *self, PyObject *iterable)
+{
+    // Special case:
+    // lists and tuples which can use PySequence_Fast ops
+    // TODO(@corona10): Add more special cases for other types.
+    int res = -1;
+    if ((PyObject *)self == iterable) {
+        Py_BEGIN_CRITICAL_SECTION(self);
+        res = list_inplace_repeat_lock_held(self, 2);
+        Py_END_CRITICAL_SECTION();
+    }
+    else if (PyList_CheckExact(iterable)) {
+        Py_BEGIN_CRITICAL_SECTION2(self, iterable);
+        res = list_extend_lock_held(self, iterable);
+        Py_END_CRITICAL_SECTION2();
+    }
+    else if (PyTuple_CheckExact(iterable)) {
+        Py_BEGIN_CRITICAL_SECTION(self);
+        res = list_extend_lock_held(self, iterable);
+        Py_END_CRITICAL_SECTION();
+    }
+    else if (PyAnySet_CheckExact(iterable)) {
+        Py_BEGIN_CRITICAL_SECTION2(self, iterable);
+        res = list_extend_set(self, (PySetObject *)iterable);
+        Py_END_CRITICAL_SECTION2();
+    }
+    else {
+        Py_BEGIN_CRITICAL_SECTION(self);
+        res = list_extend_iter_lock_held(self, iterable);
+        Py_END_CRITICAL_SECTION();
+    }
+    return res;
+}
 
 /*[clinic input]
-@critical_section self iterable
-list.extend as py_list_extend
+list.extend as list_extend
 
      iterable: object
      /
@@ -1300,12 +1340,20 @@ Extend list by appending elements from the iterable.
 [clinic start generated code]*/
 
 static PyObject *
-py_list_extend_impl(PyListObject *self, PyObject *iterable)
-/*[clinic end generated code: output=a2f115ceace2c845 input=1d42175414e1a5f3]*/
+list_extend(PyListObject *self, PyObject *iterable)
+/*[clinic end generated code: output=630fb3bca0c8e789 input=979da7597a515791]*/
 {
-    return _PyList_Extend(self, iterable);
+    if (_list_extend(self, iterable) < 0) {
+        return NULL;
+    }
+    Py_RETURN_NONE;
 }
 
+PyObject *
+_PyList_Extend(PyListObject *self, PyObject *iterable)
+{
+    return list_extend(self, iterable);
+}
 
 int
 PyList_Extend(PyObject *self, PyObject *iterable)
@@ -1314,7 +1362,7 @@ PyList_Extend(PyObject *self, PyObject *iterable)
         PyErr_BadInternalCall();
         return -1;
     }
-    return list_extend((PyListObject*)self, iterable);
+    return _list_extend((PyListObject*)self, iterable);
 }
 
 
@@ -1334,7 +1382,7 @@ static PyObject *
 list_inplace_concat(PyObject *_self, PyObject *other)
 {
     PyListObject *self = (PyListObject *)_self;
-    if (list_extend(self, other) < 0) {
+    if (_list_extend(self, other) < 0) {
         return NULL;
     }
     return Py_NewRef(self);
@@ -3168,7 +3216,7 @@ list___init___impl(PyListObject *self, PyObject *iterable)
         list_clear(self);
     }
     if (iterable != NULL) {
-        if (list_extend(self, iterable) < 0) {
+        if (_list_extend(self, iterable) < 0) {
             return -1;
         }
     }
@@ -3229,7 +3277,7 @@ static PyMethodDef list_methods[] = {
     LIST_COPY_METHODDEF
     LIST_APPEND_METHODDEF
     LIST_INSERT_METHODDEF
-    PY_LIST_EXTEND_METHODDEF
+    LIST_EXTEND_METHODDEF
     LIST_POP_METHODDEF
     LIST_REMOVE_METHODDEF
     LIST_INDEX_METHODDEF