]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement comparison ops for composites
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Jul 2022 16:48:37 +0000 (12:48 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Jul 2022 16:48:37 +0000 (12:48 -0400)
classes mapped by :class:`_orm.composite` now support
ordering comparison operations, e.g. ``<``, ``>=``, etc.

Change-Id: I44938b9ca2935b2f63c70e930768487ddc6b7669

doc/build/changelog/unreleased_20/composite_dataclass.rst
lib/sqlalchemy/orm/descriptor_props.py
test/orm/test_composites.py

index a7312b0bd47322617bb475fbe719dc81096c9bc3..d5cd70574eb02e3cb2932947ee2a725977042060 100644 (file)
@@ -6,4 +6,7 @@
     ``__composite_values__()`` method no longer needs to be implemented as this
     method is derived from inspection of the dataclass.
 
+    Additionally, classes mapped by :class:`_orm.composite` now support
+    ordering comparison operations, e.g. ``<``, ``>=``, etc.
+
     See the new documentation at :ref:`mapper_composite` for examples.
\ No newline at end of file
index 6d308e141ced95c9dd695083341a37df8ce32226..52b70b9d47a2aaded3060c896c8a518edf0fdf6b 100644 (file)
@@ -48,6 +48,7 @@ from .. import schema
 from .. import sql
 from .. import util
 from ..sql import expression
+from ..sql import operators
 from ..sql.elements import BindParameter
 from ..util.typing import is_pep593
 from ..util.typing import typing_get_args
@@ -69,6 +70,7 @@ if typing.TYPE_CHECKING:
     from ..sql._typing import _InfoType
     from ..sql.elements import ClauseList
     from ..sql.elements import ColumnElement
+    from ..sql.operators import OperatorType
     from ..sql.schema import Column
     from ..sql.selectable import Select
     from ..util.typing import _AnnotationScanType
@@ -741,21 +743,46 @@ class Composite(
                 return self.prop._comparable_elements
 
         def __eq__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
+            return self._compare(operators.eq, other)
+
+        def __ne__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
+            return self._compare(operators.ne, other)
+
+        def __lt__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
+            return self._compare(operators.lt, other)
+
+        def __gt__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
+            return self._compare(operators.gt, other)
+
+        def __le__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
+            return self._compare(operators.le, other)
+
+        def __ge__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
+            return self._compare(operators.ge, other)
+
+        # what might be interesting would be if we create
+        # an instance of the composite class itself with
+        # the columns as data members, then use "hybrid style" comparison
+        # to create these comparisons.  then your Point.__eq__() method could
+        # be where comparison behavior is defined for SQL also.   Likely
+        # not a good choice for default behavior though, not clear how it would
+        # work w/ dataclasses, etc.  also no demand for any of this anyway.
+        def _compare(
+            self, operator: OperatorType, other: Any
+        ) -> ColumnElement[bool]:
             values: Sequence[Any]
             if other is None:
                 values = [None] * len(self.prop._comparable_elements)
             else:
                 values = self.prop._composite_values_from_instance(other)
             comparisons = [
-                a == b for a, b in zip(self.prop._comparable_elements, values)
+                operator(a, b)
+                for a, b in zip(self.prop._comparable_elements, values)
             ]
             if self._adapt_to_entity:
                 assert self.adapter is not None
-                comparisons = [self.adapter(x) for x in comparisons]
-            return sql.and_(*comparisons)
-
-        def __ne__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
-            return sql.not_(self.__eq__(other))
+                comparisons = [self.adapter(x) for x in comparisons]  # type: ignore  # noqa: E501
+            return sql.and_(*comparisons)  # type: ignore
 
     def __str__(self) -> str:
         return str(self.parent.class_.__name__) + "." + self.key
index 9f3c52325d5ed275debd7a4d762bd973e4fb8cbf..3a789aff769069bcb0343d86e886f403ea8ef7eb 100644 (file)
@@ -1,4 +1,5 @@
 import dataclasses
+import operator
 
 import sqlalchemy as sa
 from sqlalchemy import ForeignKey
@@ -15,6 +16,7 @@ from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing.fixtures import fixture_session
@@ -1290,28 +1292,69 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
                 },
             )
 
-    def test_comparator_behavior_default(self):
-        self._fixture(False)
-        self._test_comparator_behavior()
-
-    def test_comparator_behavior_custom(self):
-        self._fixture(True)
-        self._test_comparator_behavior()
-
-    def _test_comparator_behavior(self):
-        Edge, Point = (self.classes.Edge, self.classes.Point)
+    @testing.combinations(True, False, argnames="custom")
+    @testing.combinations(
+        (operator.lt, "<", ">"),
+        (operator.gt, ">", "<"),
+        (operator.eq, "=", "="),
+        (operator.ne, "!=", "!="),
+        (operator.le, "<=", ">="),
+        (operator.ge, ">=", "<="),
+        argnames="operator, fwd_op, rev_op",
+    )
+    def test_comparator_behavior(self, custom, operator, fwd_op, rev_op):
+        self._fixture(custom)
+        Edge, Point = self.classes("Edge", "Point")
 
-        sess = fixture_session()
-        e1 = Edge(Point(3, 4), Point(5, 6))
-        e2 = Edge(Point(14, 5), Point(2, 7))
-        sess.add_all([e1, e2])
-        sess.commit()
-
-        assert sess.query(Edge).filter(Edge.start == Point(3, 4)).one() is e1
+        self.assert_compile(
+            select(Edge).filter(operator(Edge.start, Point(3, 4))),
+            "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge "
+            f"WHERE edge.x1 {fwd_op} :x1_1 AND edge.y1 {fwd_op} :y1_1",
+            checkparams={"x1_1": 3, "y1_1": 4},
+        )
 
-        assert sess.query(Edge).filter(Edge.start != Point(3, 4)).first() is e2
+        self.assert_compile(
+            select(Edge).filter(~operator(Edge.start, Point(3, 4))),
+            "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge "
+            f"WHERE NOT (edge.x1 {fwd_op} :x1_1 AND edge.y1 {fwd_op} :y1_1)",
+            checkparams={"x1_1": 3, "y1_1": 4},
+        )
 
-        eq_(sess.query(Edge).filter(Edge.start == None).all(), [])  # noqa
+    @testing.combinations(True, False, argnames="custom")
+    @testing.combinations(
+        (operator.lt, "<", ">"),
+        (operator.gt, ">", "<"),
+        (operator.eq, "=", "="),
+        (operator.ne, "!=", "!="),
+        (operator.le, "<=", ">="),
+        (operator.ge, ">=", "<="),
+        argnames="op, fwd_op, rev_op",
+    )
+    def test_comparator_null(self, custom, op, fwd_op, rev_op):
+        self._fixture(custom)
+        Edge, Point = self.classes("Edge", "Point")
+
+        if op is operator.eq:
+            self.assert_compile(
+                select(Edge).filter(op(Edge.start, None)),
+                "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge "
+                "WHERE edge.x1 IS NULL AND edge.y1 IS NULL",
+                checkparams={},
+            )
+        elif op is operator.ne:
+            self.assert_compile(
+                select(Edge).filter(op(Edge.start, None)),
+                "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge "
+                "WHERE edge.x1 IS NOT NULL AND edge.y1 IS NOT NULL",
+                checkparams={},
+            )
+        else:
+            with expect_raises_message(
+                sa.exc.ArgumentError,
+                r"Only '=', '!=', .* operators can be used "
+                r"with None/True/False",
+            ):
+                select(Edge).filter(op(Edge.start, None))
 
     def test_default_comparator_factory(self):
         self._fixture(False)