From: Mike Bayer Date: Thu, 2 May 2024 23:05:08 +0000 (-0400) Subject: correctly apply _set_binops_check_strict to AssociationProxy X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=01fbe18d5cb3009400d38a5d1d67f62ae4bfacc0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git correctly apply _set_binops_check_strict to AssociationProxy Revised the set "binary" operators for the association proxy ``set()`` interface to correctly raise ``TypeError`` for invalid use of the ``|``, ``&``, ``^``, and ``-`` operators, as well as the in-place mutation versions of these methods, to match the behavior of standard Python ``set()`` as well as SQLAlchemy ORM's "intstrumented" set implementation. Fixes: #11349 Change-Id: I02442f8885107d115b7ecfa1ca716835a55d4db3 --- diff --git a/doc/build/changelog/unreleased_21/11349.rst b/doc/build/changelog/unreleased_21/11349.rst new file mode 100644 index 0000000000..244713e9e3 --- /dev/null +++ b/doc/build/changelog/unreleased_21/11349.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, orm + :tickets: 11349 + + Revised the set "binary" operators for the association proxy ``set()`` + interface to correctly raise ``TypeError`` for invalid use of the ``|``, + ``&``, ``^``, and ``-`` operators, as well as the in-place mutation + versions of these methods, to match the behavior of standard Python + ``set()`` as well as SQLAlchemy ORM's "intstrumented" set implementation. + + diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 5651b1c56f..ef146f78f1 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -1873,7 +1873,7 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]): self, other: AbstractSet[_S] ) -> MutableSet[Union[_T, _S]]: if not collections._set_binops_check_strict(self, other): - raise NotImplementedError() + return NotImplemented for value in other: self.add(value) return self @@ -1885,12 +1885,16 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]): return set(self).union(*s) def __or__(self, __s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: + if not collections._set_binops_check_strict(self, __s): + return NotImplemented return self.union(__s) def difference(self, *s: Iterable[Any]) -> MutableSet[_T]: return set(self).difference(*s) def __sub__(self, s: AbstractSet[Any]) -> MutableSet[_T]: + if not collections._set_binops_check_strict(self, s): + return NotImplemented return self.difference(s) def difference_update(self, *s: Iterable[Any]) -> None: @@ -1900,7 +1904,7 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]): def __isub__(self, s: AbstractSet[Any]) -> Self: if not collections._set_binops_check_strict(self, s): - raise NotImplementedError() + return NotImplemented for value in s: self.discard(value) return self @@ -1909,6 +1913,8 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]): return set(self).intersection(*s) def __and__(self, s: AbstractSet[Any]) -> MutableSet[_T]: + if not collections._set_binops_check_strict(self, s): + return NotImplemented return self.intersection(s) def intersection_update(self, *s: Iterable[Any]) -> None: @@ -1924,7 +1930,7 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]): def __iand__(self, s: AbstractSet[Any]) -> Self: if not collections._set_binops_check_strict(self, s): - raise NotImplementedError() + return NotImplemented want = self.intersection(s) have: Set[_T] = set(self) @@ -1940,6 +1946,8 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]): return set(self).symmetric_difference(__s) def __xor__(self, s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: + if not collections._set_binops_check_strict(self, s): + return NotImplemented return self.symmetric_difference(s) def symmetric_difference_update(self, other: Iterable[Any]) -> None: @@ -1954,7 +1962,7 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]): def __ixor__(self, other: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: # type: ignore # noqa: E501 if not collections._set_binops_check_strict(self, other): - raise NotImplementedError() + return NotImplemented self.symmetric_difference_update(other) return self diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index d112680df6..394a4eaba5 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -1371,14 +1371,6 @@ def _set_binops_check_strict(self: Any, obj: Any) -> bool: return isinstance(obj, _set_binop_bases + (self.__class__,)) -def _set_binops_check_loose(self: Any, obj: Any) -> bool: - """Allow anything set-like to participate in set binops.""" - return ( - isinstance(obj, _set_binop_bases + (self.__class__,)) - or util.duck_type_collection(obj) == set - ) - - def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]: """Tailored instrumentation wrappers for any set-like class.""" diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index 7e2b31a9b5..1aca0c97e2 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -40,6 +40,7 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -735,6 +736,63 @@ class SetTest(_CollectionOperations): assert_raises(TypeError, set, [p1.children]) + def test_special_binops_checks(self): + """test for #11349""" + + Parent = self.classes.Parent + + p1 = Parent("P1") + p1.children = ["a", "b", "c"] + control = {"a", "b", "c"} + + with expect_raises(TypeError): + control | ["c", "d"] + + with expect_raises(TypeError): + p1.children | ["c", "d"] + + with expect_raises(TypeError): + control |= ["c", "d"] + + with expect_raises(TypeError): + p1.children |= ["c", "d"] + + with expect_raises(TypeError): + control & ["c", "d"] + + with expect_raises(TypeError): + p1.children & ["c", "d"] + + with expect_raises(TypeError): + control &= ["c", "d"] + + with expect_raises(TypeError): + p1.children &= ["c", "d"] + + with expect_raises(TypeError): + control ^ ["c", "d"] + + with expect_raises(TypeError): + p1.children ^ ["c", "d"] + + with expect_raises(TypeError): + control ^= ["c", "d"] + + with expect_raises(TypeError): + p1.children ^= ["c", "d"] + + with expect_raises(TypeError): + control - ["c", "d"] + + with expect_raises(TypeError): + p1.children - ["c", "d"] + + with expect_raises(TypeError): + control -= ["c", "d"] + + with expect_raises(TypeError): + p1.children -= ["c", "d"] + def test_set_comparisons(self): Parent = self.classes.Parent diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index 3afc79c918..d07dadb239 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -28,6 +28,7 @@ from sqlalchemy.orm.collections import collection from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures @@ -866,11 +867,11 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): control |= values assert_eq() - try: + with expect_raises(TypeError): + control |= [e, creator()] + + with expect_raises(TypeError): direct |= [e, creator()] - assert False - except TypeError: - assert True addall(creator(), creator()) direct.clear() @@ -924,11 +925,11 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): control -= values assert_eq() - try: + with expect_raises(TypeError): + control -= [e, creator()] + + with expect_raises(TypeError): direct -= [e, creator()] - assert False - except TypeError: - assert True if hasattr(direct, "intersection_update"): zap() @@ -965,11 +966,11 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): control &= values assert_eq() - try: + with expect_raises(TypeError): + control &= [e, creator()] + + with expect_raises(TypeError): direct &= [e, creator()] - assert False - except TypeError: - assert True if hasattr(direct, "symmetric_difference_update"): zap() @@ -1020,11 +1021,11 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): control ^= values assert_eq() - try: + with expect_raises(TypeError): + control ^= [e, creator()] + + with expect_raises(TypeError): direct ^= [e, creator()] - assert False - except TypeError: - assert True def _test_set_bulk(self, typecallable, creator=None): if creator is None: