with threading_helper.start_threads(threads):
pass
- for thread in threads:
- threading_helper.join_thread(thread)
-
# hard errors
check([clear] + [reduce] * 10)
check([clear] + [upper] * 10, bytearray(b'a' * 0x400000))
check([clear] + [zfill] * 10, bytearray(b'1' * 0x200000))
+ @unittest.skipUnless(support.Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_free_threading_bytearrayiter(self):
+ # Non-deterministic but good chance to fail if bytearrayiter is not free-threading safe.
+ # We are fishing for a "Assertion failed: object has negative ref count" and tsan races.
+
+ def iter_next(b, it):
+ b.wait()
+ list(it)
+
+ def iter_reduce(b, it):
+ b.wait()
+ it.__reduce__()
+
+ def iter_setstate(b, it):
+ b.wait()
+ it.__setstate__(0)
+
+ def check(funcs, it):
+ barrier = threading.Barrier(len(funcs))
+ threads = []
+
+ for func in funcs:
+ thread = threading.Thread(target=func, args=(barrier, it))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ for _ in range(10):
+ ba = bytearray(b'0' * 0x4000) # this is a load-bearing variable, do not remove
+
+ check([iter_next] * 10, iter(ba))
+ check([iter_next] + [iter_reduce] * 10, iter(ba)) # for tsan
+ check([iter_next] + [iter_setstate] * 10, iter(ba)) # for tsan
+
if __name__ == "__main__":
unittest.main()
bytearrayiter_next(PyObject *self)
{
bytesiterobject *it = _bytesiterobject_CAST(self);
- PyByteArrayObject *seq;
+ int val;
assert(it != NULL);
- seq = it->it_seq;
- if (seq == NULL)
+ Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
+ if (index < 0) {
return NULL;
+ }
+ PyByteArrayObject *seq = it->it_seq;
assert(PyByteArray_Check(seq));
- if (it->it_index < PyByteArray_GET_SIZE(seq)) {
- return _PyLong_FromUnsignedChar(
- (unsigned char)PyByteArray_AS_STRING(seq)[it->it_index++]);
+ Py_BEGIN_CRITICAL_SECTION(seq);
+ if (index < Py_SIZE(seq)) {
+ val = (unsigned char)PyByteArray_AS_STRING(seq)[index];
+ }
+ else {
+ val = -1;
}
+ Py_END_CRITICAL_SECTION();
- it->it_seq = NULL;
- Py_DECREF(seq);
- return NULL;
+ if (val == -1) {
+ FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, -1);
+#ifndef Py_GIL_DISABLED
+ Py_CLEAR(it->it_seq);
+#endif
+ return NULL;
+ }
+ FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index + 1);
+ return _PyLong_FromUnsignedChar((unsigned char)val);
}
static PyObject *
{
bytesiterobject *it = _bytesiterobject_CAST(self);
Py_ssize_t len = 0;
- if (it->it_seq) {
- len = PyByteArray_GET_SIZE(it->it_seq) - it->it_index;
+ Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
+ if (index >= 0) {
+ len = PyByteArray_GET_SIZE(it->it_seq) - index;
if (len < 0) {
len = 0;
}
* call must be before access of iterator pointers.
* see issue #101765 */
bytesiterobject *it = _bytesiterobject_CAST(self);
- if (it->it_seq != NULL) {
- return Py_BuildValue("N(O)n", iter, it->it_seq, it->it_index);
- } else {
- return Py_BuildValue("N(())", iter);
+ Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
+ if (index >= 0) {
+ return Py_BuildValue("N(O)n", iter, it->it_seq, index);
}
+ return Py_BuildValue("N(())", iter);
}
static PyObject *
bytearrayiter_setstate(PyObject *self, PyObject *state)
{
Py_ssize_t index = PyLong_AsSsize_t(state);
- if (index == -1 && PyErr_Occurred())
+ if (index == -1 && PyErr_Occurred()) {
return NULL;
+ }
bytesiterobject *it = _bytesiterobject_CAST(self);
- if (it->it_seq != NULL) {
- if (index < 0)
- index = 0;
- else if (index > PyByteArray_GET_SIZE(it->it_seq))
- index = PyByteArray_GET_SIZE(it->it_seq); /* iterator exhausted */
- it->it_index = index;
+ if (FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index) >= 0) {
+ if (index < -1) {
+ index = -1;
+ }
+ else {
+ Py_ssize_t size = PyByteArray_GET_SIZE(it->it_seq);
+ if (index > size) {
+ index = size; /* iterator at end */
+ }
+ }
+ FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index);
}
Py_RETURN_NONE;
}
it = PyObject_GC_New(bytesiterobject, &PyByteArrayIter_Type);
if (it == NULL)
return NULL;
- it->it_index = 0;
+ it->it_index = 0; // -1 indicates exhausted
it->it_seq = (PyByteArrayObject *)Py_NewRef(seq);
_PyObject_GC_TRACK(it);
return (PyObject *)it;