From: Dennis Sweeney <36520290+sweeneyde@users.noreply.github.com> Date: Sun, 13 Feb 2022 10:29:42 +0000 (-0500) Subject: bpo-46615: Don't crash when set operations mutate the sets (GH-31120) (GH-31312) X-Git-Tag: v3.9.11~80 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=c31b8a97a8a7e8255231c9e12ed581c6240c0d6c;p=thirdparty%2FPython%2Fcpython.git bpo-46615: Don't crash when set operations mutate the sets (GH-31120) (GH-31312) Ensure strong references are acquired whenever using `set_next()`. Added randomized test cases for `__eq__` methods that sometimes mutate sets when called. (cherry picked from commit 4a66615ba736f84eadf9456bfd5d32a94cccf117) --- diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index ca2f4e28ae80..ec7433c87d41 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -1799,6 +1799,192 @@ class TestWeirdBugs(unittest.TestCase): s = {0} s.update(other) + +class TestOperationsMutating: + """Regression test for bpo-46615""" + + constructor1 = None + constructor2 = None + + def make_sets_of_bad_objects(self): + class Bad: + def __eq__(self, other): + if not enabled: + return False + if randrange(20) == 0: + set1.clear() + if randrange(20) == 0: + set2.clear() + return bool(randrange(2)) + def __hash__(self): + return randrange(2) + # Don't behave poorly during construction. + enabled = False + set1 = self.constructor1(Bad() for _ in range(randrange(50))) + set2 = self.constructor2(Bad() for _ in range(randrange(50))) + # Now start behaving poorly + enabled = True + return set1, set2 + + def check_set_op_does_not_crash(self, function): + for _ in range(100): + set1, set2 = self.make_sets_of_bad_objects() + try: + function(set1, set2) + except RuntimeError as e: + # Just make sure we don't crash here. + self.assertIn("changed size during iteration", str(e)) + + +class TestBinaryOpsMutating(TestOperationsMutating): + + def test_eq_with_mutation(self): + self.check_set_op_does_not_crash(lambda a, b: a == b) + + def test_ne_with_mutation(self): + self.check_set_op_does_not_crash(lambda a, b: a != b) + + def test_lt_with_mutation(self): + self.check_set_op_does_not_crash(lambda a, b: a < b) + + def test_le_with_mutation(self): + self.check_set_op_does_not_crash(lambda a, b: a <= b) + + def test_gt_with_mutation(self): + self.check_set_op_does_not_crash(lambda a, b: a > b) + + def test_ge_with_mutation(self): + self.check_set_op_does_not_crash(lambda a, b: a >= b) + + def test_and_with_mutation(self): + self.check_set_op_does_not_crash(lambda a, b: a & b) + + def test_or_with_mutation(self): + self.check_set_op_does_not_crash(lambda a, b: a | b) + + def test_sub_with_mutation(self): + self.check_set_op_does_not_crash(lambda a, b: a - b) + + def test_xor_with_mutation(self): + self.check_set_op_does_not_crash(lambda a, b: a ^ b) + + def test_iadd_with_mutation(self): + def f(a, b): + a &= b + self.check_set_op_does_not_crash(f) + + def test_ior_with_mutation(self): + def f(a, b): + a |= b + self.check_set_op_does_not_crash(f) + + def test_isub_with_mutation(self): + def f(a, b): + a -= b + self.check_set_op_does_not_crash(f) + + def test_ixor_with_mutation(self): + def f(a, b): + a ^= b + self.check_set_op_does_not_crash(f) + + def test_iteration_with_mutation(self): + def f1(a, b): + for x in a: + pass + for y in b: + pass + def f2(a, b): + for y in b: + pass + for x in a: + pass + def f3(a, b): + for x, y in zip(a, b): + pass + self.check_set_op_does_not_crash(f1) + self.check_set_op_does_not_crash(f2) + self.check_set_op_does_not_crash(f3) + + +class TestBinaryOpsMutating_Set_Set(TestBinaryOpsMutating, unittest.TestCase): + constructor1 = set + constructor2 = set + +class TestBinaryOpsMutating_Subclass_Subclass(TestBinaryOpsMutating, unittest.TestCase): + constructor1 = SetSubclass + constructor2 = SetSubclass + +class TestBinaryOpsMutating_Set_Subclass(TestBinaryOpsMutating, unittest.TestCase): + constructor1 = set + constructor2 = SetSubclass + +class TestBinaryOpsMutating_Subclass_Set(TestBinaryOpsMutating, unittest.TestCase): + constructor1 = SetSubclass + constructor2 = set + + +class TestMethodsMutating(TestOperationsMutating): + + def test_issubset_with_mutation(self): + self.check_set_op_does_not_crash(set.issubset) + + def test_issuperset_with_mutation(self): + self.check_set_op_does_not_crash(set.issuperset) + + def test_intersection_with_mutation(self): + self.check_set_op_does_not_crash(set.intersection) + + def test_union_with_mutation(self): + self.check_set_op_does_not_crash(set.union) + + def test_difference_with_mutation(self): + self.check_set_op_does_not_crash(set.difference) + + def test_symmetric_difference_with_mutation(self): + self.check_set_op_does_not_crash(set.symmetric_difference) + + def test_isdisjoint_with_mutation(self): + self.check_set_op_does_not_crash(set.isdisjoint) + + def test_difference_update_with_mutation(self): + self.check_set_op_does_not_crash(set.difference_update) + + def test_intersection_update_with_mutation(self): + self.check_set_op_does_not_crash(set.intersection_update) + + def test_symmetric_difference_update_with_mutation(self): + self.check_set_op_does_not_crash(set.symmetric_difference_update) + + def test_update_with_mutation(self): + self.check_set_op_does_not_crash(set.update) + + +class TestMethodsMutating_Set_Set(TestMethodsMutating, unittest.TestCase): + constructor1 = set + constructor2 = set + +class TestMethodsMutating_Subclass_Subclass(TestMethodsMutating, unittest.TestCase): + constructor1 = SetSubclass + constructor2 = SetSubclass + +class TestMethodsMutating_Set_Subclass(TestMethodsMutating, unittest.TestCase): + constructor1 = set + constructor2 = SetSubclass + +class TestMethodsMutating_Subclass_Set(TestMethodsMutating, unittest.TestCase): + constructor1 = SetSubclass + constructor2 = set + +class TestMethodsMutating_Set_Dict(TestMethodsMutating, unittest.TestCase): + constructor1 = set + constructor2 = dict.fromkeys + +class TestMethodsMutating_Set_List(TestMethodsMutating, unittest.TestCase): + constructor1 = set + constructor2 = list + + # Application tests (based on David Eppstein's graph recipes ==================================== def powerset(U): diff --git a/Misc/NEWS.d/next/Core and Builtins/2022-02-04-04-33-18.bpo-46615.puArY9.rst b/Misc/NEWS.d/next/Core and Builtins/2022-02-04-04-33-18.bpo-46615.puArY9.rst new file mode 100644 index 000000000000..6dee92a546e3 --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2022-02-04-04-33-18.bpo-46615.puArY9.rst @@ -0,0 +1 @@ +When iterating over sets internally in ``setobject.c``, acquire strong references to the resulting items from the set. This prevents crashes in corner-cases of various set operations where the set gets mutated. diff --git a/Objects/setobject.c b/Objects/setobject.c index 4bd5777f967d..6d156bd4e082 100644 --- a/Objects/setobject.c +++ b/Objects/setobject.c @@ -1207,17 +1207,21 @@ set_intersection(PySetObject *so, PyObject *other) while (set_next((PySetObject *)other, &pos, &entry)) { key = entry->key; hash = entry->hash; + Py_INCREF(key); rv = set_contains_entry(so, key, hash); if (rv < 0) { Py_DECREF(result); + Py_DECREF(key); return NULL; } if (rv) { if (set_add_entry(result, key, hash)) { Py_DECREF(result); + Py_DECREF(key); return NULL; } } + Py_DECREF(key); } return (PyObject *)result; } @@ -1357,11 +1361,16 @@ set_isdisjoint(PySetObject *so, PyObject *other) other = tmp; } while (set_next((PySetObject *)other, &pos, &entry)) { - rv = set_contains_entry(so, entry->key, entry->hash); - if (rv < 0) + PyObject *key = entry->key; + Py_INCREF(key); + rv = set_contains_entry(so, key, entry->hash); + Py_DECREF(key); + if (rv < 0) { return NULL; - if (rv) + } + if (rv) { Py_RETURN_FALSE; + } } Py_RETURN_TRUE; } @@ -1420,11 +1429,16 @@ set_difference_update_internal(PySetObject *so, PyObject *other) Py_INCREF(other); } - while (set_next((PySetObject *)other, &pos, &entry)) - if (set_discard_entry(so, entry->key, entry->hash) < 0) { + while (set_next((PySetObject *)other, &pos, &entry)) { + PyObject *key = entry->key; + Py_INCREF(key); + if (set_discard_entry(so, key, entry->hash) < 0) { Py_DECREF(other); + Py_DECREF(key); return -1; } + Py_DECREF(key); + } Py_DECREF(other); } else { @@ -1515,17 +1529,21 @@ set_difference(PySetObject *so, PyObject *other) while (set_next(so, &pos, &entry)) { key = entry->key; hash = entry->hash; + Py_INCREF(key); rv = _PyDict_Contains(other, key, hash); if (rv < 0) { Py_DECREF(result); + Py_DECREF(key); return NULL; } if (!rv) { if (set_add_entry((PySetObject *)result, key, hash)) { Py_DECREF(result); + Py_DECREF(key); return NULL; } } + Py_DECREF(key); } return result; } @@ -1534,17 +1552,21 @@ set_difference(PySetObject *so, PyObject *other) while (set_next(so, &pos, &entry)) { key = entry->key; hash = entry->hash; + Py_INCREF(key); rv = set_contains_entry((PySetObject *)other, key, hash); if (rv < 0) { Py_DECREF(result); + Py_DECREF(key); return NULL; } if (!rv) { if (set_add_entry((PySetObject *)result, key, hash)) { Py_DECREF(result); + Py_DECREF(key); return NULL; } } + Py_DECREF(key); } return result; } @@ -1641,17 +1663,21 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other) while (set_next(otherset, &pos, &entry)) { key = entry->key; hash = entry->hash; + Py_INCREF(key); rv = set_discard_entry(so, key, hash); if (rv < 0) { Py_DECREF(otherset); + Py_DECREF(key); return NULL; } if (rv == DISCARD_NOTFOUND) { if (set_add_entry(so, key, hash)) { Py_DECREF(otherset); + Py_DECREF(key); return NULL; } } + Py_DECREF(key); } Py_DECREF(otherset); Py_RETURN_NONE; @@ -1726,11 +1752,16 @@ set_issubset(PySetObject *so, PyObject *other) Py_RETURN_FALSE; while (set_next(so, &pos, &entry)) { - rv = set_contains_entry((PySetObject *)other, entry->key, entry->hash); - if (rv < 0) + PyObject *key = entry->key; + Py_INCREF(key); + rv = set_contains_entry((PySetObject *)other, key, entry->hash); + Py_DECREF(key); + if (rv < 0) { return NULL; - if (!rv) + } + if (!rv) { Py_RETURN_FALSE; + } } Py_RETURN_TRUE; }