]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
correctly apply _set_binops_check_strict to AssociationProxy
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 2 May 2024 23:05:08 +0000 (19:05 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 2 May 2024 23:24:34 +0000 (19:24 -0400)
Revised the set "binary" operators for the association proxy ``set()``
interface to correctly raise ``TypeError`` for invalid use of the ``|``,
``&``, ``^``, and ``-`` operators, as well as the in-place mutation
versions of these methods, to match the behavior of standard Python
``set()`` as well as SQLAlchemy ORM's "intstrumented" set implementation.

Fixes: #11349
Change-Id: I02442f8885107d115b7ecfa1ca716835a55d4db3

doc/build/changelog/unreleased_21/11349.rst [new file with mode: 0644]
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/orm/collections.py
test/ext/test_associationproxy.py
test/orm/test_collection.py

diff --git a/doc/build/changelog/unreleased_21/11349.rst b/doc/build/changelog/unreleased_21/11349.rst
new file mode 100644 (file)
index 0000000..244713e
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11349
+
+    Revised the set "binary" operators for the association proxy ``set()``
+    interface to correctly raise ``TypeError`` for invalid use of the ``|``,
+    ``&``, ``^``, and ``-`` operators, as well as the in-place mutation
+    versions of these methods, to match the behavior of standard Python
+    ``set()`` as well as SQLAlchemy ORM's "intstrumented" set implementation.
+
+
index 5651b1c56f3c284495fce008e6bd63b5627cee5f..ef146f78f16868759015216cb4e10228273a8aec 100644 (file)
@@ -1873,7 +1873,7 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]):
         self, other: AbstractSet[_S]
     ) -> MutableSet[Union[_T, _S]]:
         if not collections._set_binops_check_strict(self, other):
-            raise NotImplementedError()
+            return NotImplemented
         for value in other:
             self.add(value)
         return self
@@ -1885,12 +1885,16 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]):
         return set(self).union(*s)
 
     def __or__(self, __s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]:
+        if not collections._set_binops_check_strict(self, __s):
+            return NotImplemented
         return self.union(__s)
 
     def difference(self, *s: Iterable[Any]) -> MutableSet[_T]:
         return set(self).difference(*s)
 
     def __sub__(self, s: AbstractSet[Any]) -> MutableSet[_T]:
+        if not collections._set_binops_check_strict(self, s):
+            return NotImplemented
         return self.difference(s)
 
     def difference_update(self, *s: Iterable[Any]) -> None:
@@ -1900,7 +1904,7 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]):
 
     def __isub__(self, s: AbstractSet[Any]) -> Self:
         if not collections._set_binops_check_strict(self, s):
-            raise NotImplementedError()
+            return NotImplemented
         for value in s:
             self.discard(value)
         return self
@@ -1909,6 +1913,8 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]):
         return set(self).intersection(*s)
 
     def __and__(self, s: AbstractSet[Any]) -> MutableSet[_T]:
+        if not collections._set_binops_check_strict(self, s):
+            return NotImplemented
         return self.intersection(s)
 
     def intersection_update(self, *s: Iterable[Any]) -> None:
@@ -1924,7 +1930,7 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]):
 
     def __iand__(self, s: AbstractSet[Any]) -> Self:
         if not collections._set_binops_check_strict(self, s):
-            raise NotImplementedError()
+            return NotImplemented
         want = self.intersection(s)
         have: Set[_T] = set(self)
 
@@ -1940,6 +1946,8 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]):
         return set(self).symmetric_difference(__s)
 
     def __xor__(self, s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]:
+        if not collections._set_binops_check_strict(self, s):
+            return NotImplemented
         return self.symmetric_difference(s)
 
     def symmetric_difference_update(self, other: Iterable[Any]) -> None:
@@ -1954,7 +1962,7 @@ class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]):
 
     def __ixor__(self, other: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]:  # type: ignore  # noqa: E501
         if not collections._set_binops_check_strict(self, other):
-            raise NotImplementedError()
+            return NotImplemented
 
         self.symmetric_difference_update(other)
         return self
index d112680df6ef49330445d110c2ea8f1b0b89bacb..394a4eaba548f807ed5ba1100c05981a20caed19 100644 (file)
@@ -1371,14 +1371,6 @@ def _set_binops_check_strict(self: Any, obj: Any) -> bool:
     return isinstance(obj, _set_binop_bases + (self.__class__,))
 
 
-def _set_binops_check_loose(self: Any, obj: Any) -> bool:
-    """Allow anything set-like to participate in set binops."""
-    return (
-        isinstance(obj, _set_binop_bases + (self.__class__,))
-        or util.duck_type_collection(obj) == set
-    )
-
-
 def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
     """Tailored instrumentation wrappers for any set-like class."""
 
index 7e2b31a9b5ba6a927f587a6f279c96ee0282a1bc..1aca0c97e259ad9109abbae84f8369fb7ce41f3e 100644 (file)
@@ -40,6 +40,7 @@ from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -735,6 +736,63 @@ class SetTest(_CollectionOperations):
 
         assert_raises(TypeError, set, [p1.children])
 
+    def test_special_binops_checks(self):
+        """test for #11349"""
+
+        Parent = self.classes.Parent
+
+        p1 = Parent("P1")
+        p1.children = ["a", "b", "c"]
+        control = {"a", "b", "c"}
+
+        with expect_raises(TypeError):
+            control | ["c", "d"]
+
+        with expect_raises(TypeError):
+            p1.children | ["c", "d"]
+
+        with expect_raises(TypeError):
+            control |= ["c", "d"]
+
+        with expect_raises(TypeError):
+            p1.children |= ["c", "d"]
+
+        with expect_raises(TypeError):
+            control & ["c", "d"]
+
+        with expect_raises(TypeError):
+            p1.children & ["c", "d"]
+
+        with expect_raises(TypeError):
+            control &= ["c", "d"]
+
+        with expect_raises(TypeError):
+            p1.children &= ["c", "d"]
+
+        with expect_raises(TypeError):
+            control ^ ["c", "d"]
+
+        with expect_raises(TypeError):
+            p1.children ^ ["c", "d"]
+
+        with expect_raises(TypeError):
+            control ^= ["c", "d"]
+
+        with expect_raises(TypeError):
+            p1.children ^= ["c", "d"]
+
+        with expect_raises(TypeError):
+            control - ["c", "d"]
+
+        with expect_raises(TypeError):
+            p1.children - ["c", "d"]
+
+        with expect_raises(TypeError):
+            control -= ["c", "d"]
+
+        with expect_raises(TypeError):
+            p1.children -= ["c", "d"]
+
     def test_set_comparisons(self):
         Parent = self.classes.Parent
 
index 3afc79c918ab953e37dd0af5dd635373e1931a8d..d07dadb239b03870c26b81e22478f61dd18d0c2d 100644 (file)
@@ -28,6 +28,7 @@ from sqlalchemy.orm.collections import collection
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
@@ -866,11 +867,11 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
             control |= values
             assert_eq()
 
-            try:
+            with expect_raises(TypeError):
+                control |= [e, creator()]
+
+            with expect_raises(TypeError):
                 direct |= [e, creator()]
-                assert False
-            except TypeError:
-                assert True
 
         addall(creator(), creator())
         direct.clear()
@@ -924,11 +925,11 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
             control -= values
             assert_eq()
 
-            try:
+            with expect_raises(TypeError):
+                control -= [e, creator()]
+
+            with expect_raises(TypeError):
                 direct -= [e, creator()]
-                assert False
-            except TypeError:
-                assert True
 
         if hasattr(direct, "intersection_update"):
             zap()
@@ -965,11 +966,11 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
             control &= values
             assert_eq()
 
-            try:
+            with expect_raises(TypeError):
+                control &= [e, creator()]
+
+            with expect_raises(TypeError):
                 direct &= [e, creator()]
-                assert False
-            except TypeError:
-                assert True
 
         if hasattr(direct, "symmetric_difference_update"):
             zap()
@@ -1020,11 +1021,11 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
             control ^= values
             assert_eq()
 
-            try:
+            with expect_raises(TypeError):
+                control ^= [e, creator()]
+
+            with expect_raises(TypeError):
                 direct ^= [e, creator()]
-                assert False
-            except TypeError:
-                assert True
 
     def _test_set_bulk(self, typecallable, creator=None):
         if creator is None: