from .. import sql
from .. import util
from ..sql import expression
+from ..sql import operators
from ..sql.elements import BindParameter
from ..util.typing import is_pep593
from ..util.typing import typing_get_args
from ..sql._typing import _InfoType
from ..sql.elements import ClauseList
from ..sql.elements import ColumnElement
+ from ..sql.operators import OperatorType
from ..sql.schema import Column
from ..sql.selectable import Select
from ..util.typing import _AnnotationScanType
return self.prop._comparable_elements
def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
+ return self._compare(operators.eq, other)
+
+ def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
+ return self._compare(operators.ne, other)
+
+ def __lt__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
+ return self._compare(operators.lt, other)
+
+ def __gt__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
+ return self._compare(operators.gt, other)
+
+ def __le__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
+ return self._compare(operators.le, other)
+
+ def __ge__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
+ return self._compare(operators.ge, other)
+
+ # what might be interesting would be if we create
+ # an instance of the composite class itself with
+ # the columns as data members, then use "hybrid style" comparison
+ # to create these comparisons. then your Point.__eq__() method could
+ # be where comparison behavior is defined for SQL also. Likely
+ # not a good choice for default behavior though, not clear how it would
+ # work w/ dataclasses, etc. also no demand for any of this anyway.
+ def _compare(
+ self, operator: OperatorType, other: Any
+ ) -> ColumnElement[bool]:
values: Sequence[Any]
if other is None:
values = [None] * len(self.prop._comparable_elements)
else:
values = self.prop._composite_values_from_instance(other)
comparisons = [
- a == b for a, b in zip(self.prop._comparable_elements, values)
+ operator(a, b)
+ for a, b in zip(self.prop._comparable_elements, values)
]
if self._adapt_to_entity:
assert self.adapter is not None
- comparisons = [self.adapter(x) for x in comparisons]
- return sql.and_(*comparisons)
-
- def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
- return sql.not_(self.__eq__(other))
+ comparisons = [self.adapter(x) for x in comparisons] # type: ignore # noqa: E501
+ return sql.and_(*comparisons) # type: ignore
def __str__(self) -> str:
return str(self.parent.class_.__name__) + "." + self.key
import dataclasses
+import operator
import sqlalchemy as sa
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Session
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing.fixtures import fixture_session
},
)
- def test_comparator_behavior_default(self):
- self._fixture(False)
- self._test_comparator_behavior()
-
- def test_comparator_behavior_custom(self):
- self._fixture(True)
- self._test_comparator_behavior()
-
- def _test_comparator_behavior(self):
- Edge, Point = (self.classes.Edge, self.classes.Point)
+ @testing.combinations(True, False, argnames="custom")
+ @testing.combinations(
+ (operator.lt, "<", ">"),
+ (operator.gt, ">", "<"),
+ (operator.eq, "=", "="),
+ (operator.ne, "!=", "!="),
+ (operator.le, "<=", ">="),
+ (operator.ge, ">=", "<="),
+ argnames="operator, fwd_op, rev_op",
+ )
+ def test_comparator_behavior(self, custom, operator, fwd_op, rev_op):
+ self._fixture(custom)
+ Edge, Point = self.classes("Edge", "Point")
- sess = fixture_session()
- e1 = Edge(Point(3, 4), Point(5, 6))
- e2 = Edge(Point(14, 5), Point(2, 7))
- sess.add_all([e1, e2])
- sess.commit()
-
- assert sess.query(Edge).filter(Edge.start == Point(3, 4)).one() is e1
+ self.assert_compile(
+ select(Edge).filter(operator(Edge.start, Point(3, 4))),
+ "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge "
+ f"WHERE edge.x1 {fwd_op} :x1_1 AND edge.y1 {fwd_op} :y1_1",
+ checkparams={"x1_1": 3, "y1_1": 4},
+ )
- assert sess.query(Edge).filter(Edge.start != Point(3, 4)).first() is e2
+ self.assert_compile(
+ select(Edge).filter(~operator(Edge.start, Point(3, 4))),
+ "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge "
+ f"WHERE NOT (edge.x1 {fwd_op} :x1_1 AND edge.y1 {fwd_op} :y1_1)",
+ checkparams={"x1_1": 3, "y1_1": 4},
+ )
- eq_(sess.query(Edge).filter(Edge.start == None).all(), []) # noqa
+ @testing.combinations(True, False, argnames="custom")
+ @testing.combinations(
+ (operator.lt, "<", ">"),
+ (operator.gt, ">", "<"),
+ (operator.eq, "=", "="),
+ (operator.ne, "!=", "!="),
+ (operator.le, "<=", ">="),
+ (operator.ge, ">=", "<="),
+ argnames="op, fwd_op, rev_op",
+ )
+ def test_comparator_null(self, custom, op, fwd_op, rev_op):
+ self._fixture(custom)
+ Edge, Point = self.classes("Edge", "Point")
+
+ if op is operator.eq:
+ self.assert_compile(
+ select(Edge).filter(op(Edge.start, None)),
+ "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge "
+ "WHERE edge.x1 IS NULL AND edge.y1 IS NULL",
+ checkparams={},
+ )
+ elif op is operator.ne:
+ self.assert_compile(
+ select(Edge).filter(op(Edge.start, None)),
+ "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge "
+ "WHERE edge.x1 IS NOT NULL AND edge.y1 IS NOT NULL",
+ checkparams={},
+ )
+ else:
+ with expect_raises_message(
+ sa.exc.ArgumentError,
+ r"Only '=', '!=', .* operators can be used "
+ r"with None/True/False",
+ ):
+ select(Edge).filter(op(Edge.start, None))
def test_default_comparator_factory(self):
self._fixture(False)