From: Mike Bayer Date: Wed, 13 Jul 2022 16:48:37 +0000 (-0400) Subject: implement comparison ops for composites X-Git-Tag: rel_2_0_0b1~181^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9d3eaae35b9a3bd74114b350f84281ba9e7fb993;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement comparison ops for composites classes mapped by :class:`_orm.composite` now support ordering comparison operations, e.g. ``<``, ``>=``, etc. Change-Id: I44938b9ca2935b2f63c70e930768487ddc6b7669 --- diff --git a/doc/build/changelog/unreleased_20/composite_dataclass.rst b/doc/build/changelog/unreleased_20/composite_dataclass.rst index a7312b0bd4..d5cd70574e 100644 --- a/doc/build/changelog/unreleased_20/composite_dataclass.rst +++ b/doc/build/changelog/unreleased_20/composite_dataclass.rst @@ -6,4 +6,7 @@ ``__composite_values__()`` method no longer needs to be implemented as this method is derived from inspection of the dataclass. + Additionally, classes mapped by :class:`_orm.composite` now support + ordering comparison operations, e.g. ``<``, ``>=``, etc. + See the new documentation at :ref:`mapper_composite` for examples. \ No newline at end of file diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 6d308e141c..52b70b9d47 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -48,6 +48,7 @@ from .. import schema from .. import sql from .. import util from ..sql import expression +from ..sql import operators from ..sql.elements import BindParameter from ..util.typing import is_pep593 from ..util.typing import typing_get_args @@ -69,6 +70,7 @@ if typing.TYPE_CHECKING: from ..sql._typing import _InfoType from ..sql.elements import ClauseList from ..sql.elements import ColumnElement + from ..sql.operators import OperatorType from ..sql.schema import Column from ..sql.selectable import Select from ..util.typing import _AnnotationScanType @@ -741,21 +743,46 @@ class Composite( return self.prop._comparable_elements def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.eq, other) + + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.ne, other) + + def __lt__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.lt, other) + + def __gt__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.gt, other) + + def __le__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.le, other) + + def __ge__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.ge, other) + + # what might be interesting would be if we create + # an instance of the composite class itself with + # the columns as data members, then use "hybrid style" comparison + # to create these comparisons. then your Point.__eq__() method could + # be where comparison behavior is defined for SQL also. Likely + # not a good choice for default behavior though, not clear how it would + # work w/ dataclasses, etc. also no demand for any of this anyway. + def _compare( + self, operator: OperatorType, other: Any + ) -> ColumnElement[bool]: values: Sequence[Any] if other is None: values = [None] * len(self.prop._comparable_elements) else: values = self.prop._composite_values_from_instance(other) comparisons = [ - a == b for a, b in zip(self.prop._comparable_elements, values) + operator(a, b) + for a, b in zip(self.prop._comparable_elements, values) ] if self._adapt_to_entity: assert self.adapter is not None - comparisons = [self.adapter(x) for x in comparisons] - return sql.and_(*comparisons) - - def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 - return sql.not_(self.__eq__(other)) + comparisons = [self.adapter(x) for x in comparisons] # type: ignore # noqa: E501 + return sql.and_(*comparisons) # type: ignore def __str__(self) -> str: return str(self.parent.class_.__name__) + "." + self.key diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index 9f3c52325d..3a789aff76 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -1,4 +1,5 @@ import dataclasses +import operator import sqlalchemy as sa from sqlalchemy import ForeignKey @@ -15,6 +16,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing.fixtures import fixture_session @@ -1290,28 +1292,69 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): }, ) - def test_comparator_behavior_default(self): - self._fixture(False) - self._test_comparator_behavior() - - def test_comparator_behavior_custom(self): - self._fixture(True) - self._test_comparator_behavior() - - def _test_comparator_behavior(self): - Edge, Point = (self.classes.Edge, self.classes.Point) + @testing.combinations(True, False, argnames="custom") + @testing.combinations( + (operator.lt, "<", ">"), + (operator.gt, ">", "<"), + (operator.eq, "=", "="), + (operator.ne, "!=", "!="), + (operator.le, "<=", ">="), + (operator.ge, ">=", "<="), + argnames="operator, fwd_op, rev_op", + ) + def test_comparator_behavior(self, custom, operator, fwd_op, rev_op): + self._fixture(custom) + Edge, Point = self.classes("Edge", "Point") - sess = fixture_session() - e1 = Edge(Point(3, 4), Point(5, 6)) - e2 = Edge(Point(14, 5), Point(2, 7)) - sess.add_all([e1, e2]) - sess.commit() - - assert sess.query(Edge).filter(Edge.start == Point(3, 4)).one() is e1 + self.assert_compile( + select(Edge).filter(operator(Edge.start, Point(3, 4))), + "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge " + f"WHERE edge.x1 {fwd_op} :x1_1 AND edge.y1 {fwd_op} :y1_1", + checkparams={"x1_1": 3, "y1_1": 4}, + ) - assert sess.query(Edge).filter(Edge.start != Point(3, 4)).first() is e2 + self.assert_compile( + select(Edge).filter(~operator(Edge.start, Point(3, 4))), + "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge " + f"WHERE NOT (edge.x1 {fwd_op} :x1_1 AND edge.y1 {fwd_op} :y1_1)", + checkparams={"x1_1": 3, "y1_1": 4}, + ) - eq_(sess.query(Edge).filter(Edge.start == None).all(), []) # noqa + @testing.combinations(True, False, argnames="custom") + @testing.combinations( + (operator.lt, "<", ">"), + (operator.gt, ">", "<"), + (operator.eq, "=", "="), + (operator.ne, "!=", "!="), + (operator.le, "<=", ">="), + (operator.ge, ">=", "<="), + argnames="op, fwd_op, rev_op", + ) + def test_comparator_null(self, custom, op, fwd_op, rev_op): + self._fixture(custom) + Edge, Point = self.classes("Edge", "Point") + + if op is operator.eq: + self.assert_compile( + select(Edge).filter(op(Edge.start, None)), + "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge " + "WHERE edge.x1 IS NULL AND edge.y1 IS NULL", + checkparams={}, + ) + elif op is operator.ne: + self.assert_compile( + select(Edge).filter(op(Edge.start, None)), + "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge " + "WHERE edge.x1 IS NOT NULL AND edge.y1 IS NOT NULL", + checkparams={}, + ) + else: + with expect_raises_message( + sa.exc.ArgumentError, + r"Only '=', '!=', .* operators can be used " + r"with None/True/False", + ): + select(Edge).filter(op(Edge.start, None)) def test_default_comparator_factory(self): self._fixture(False)