util, select
from test.lib.schema import Table, Column
from sqlalchemy.orm import mapper, relationship, backref, \
- class_mapper, \
+ class_mapper, CompositeProperty, \
validates, aliased
from sqlalchemy.orm import attributes, \
composite, relationship, \
deferred=True)
})
self._test_roundtrip()
+
+class ComparatorTest(fixtures.MappedTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ Table('edge', metadata,
+ Column('id', Integer, primary_key=True,
+ test_needs_autoincrement=True),
+ Column('x1', Integer),
+ Column('y1', Integer),
+ Column('x2', Integer),
+ Column('y2', Integer),
+ )
+
+ @classmethod
+ def setup_mappers(cls):
+ class Point(cls.Comparable):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+ def __composite_values__(self):
+ return [self.x, self.y]
+ def __eq__(self, other):
+ return isinstance(other, Point) and \
+ other.x == self.x and \
+ other.y == self.y
+ def __ne__(self, other):
+ return not isinstance(other, Point) or \
+ not self.__eq__(other)
+
+ class Edge(cls.Comparable):
+ def __init__(self, start, end):
+ self.start = start
+ self.end = end
+
+ def __eq__(self, other):
+ return isinstance(other, Edge) and \
+ other.id == self.id
+
+ def _fixture(self, custom):
+ edge, Edge, Point = (self.tables.edge,
+ self.classes.Edge,
+ self.classes.Point)
+
+ if custom:
+ class CustomComparator(sa.orm.CompositeProperty.Comparator):
+ def near(self, other, d):
+ clauses = self.__clause_element__().clauses
+ diff_x = clauses[0] - other.x
+ diff_y = clauses[1] - other.y
+ return diff_x * diff_x + diff_y * diff_y <= d * d
+
+ mapper(Edge, edge, properties={
+ 'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1,
+ comparator_factory=CustomComparator),
+ 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2)
+ })
+ else:
+ mapper(Edge, edge, properties={
+ 'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1),
+ 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2)
+ })
+
+ def test_comparator_behavior_default(self):
+ self._fixture(False)
+ self._test_comparator_behavior()
+
+ def test_comparator_behavior_custom(self):
+ self._fixture(True)
+ self._test_comparator_behavior()
+
+ def _test_comparator_behavior(self):
+ Edge, Point = (self.classes.Edge,
+ self.classes.Point)
+
+ sess = Session()
+ e1 = Edge(Point(3, 4), Point(5, 6))
+ e2 = Edge(Point(14, 5), Point(2, 7))
+ sess.add_all([e1, e2])
+ sess.commit()
+
+ assert sess.query(Edge).\
+ filter(Edge.start==Point(3, 4)).one() is \
+ e1
+
+ assert sess.query(Edge).\
+ filter(Edge.start!=Point(3, 4)).first() is \
+ e2
+
+ eq_(
+ sess.query(Edge).filter(Edge.start==None).all(),
+ []
+ )
+
+ def test_default_comparator_factory(self):
+ self._fixture(False)
+ Edge = self.classes.Edge
+ start_prop = Edge.start.property
+
+ assert start_prop.comparator_factory is CompositeProperty.Comparator
+
+ def test_custom_comparator_factory(self):
+ self._fixture(True)
+ Edge, Point = (self.classes.Edge,
+ self.classes.Point)
+
+ edge_1, edge_2 = Edge(Point(0, 0), Point(3, 5)), \
+ Edge(Point(0, 1), Point(3, 5))
+
+ sess = Session()
+ sess.add_all([edge_1, edge_2])
+ sess.commit()
+
+ near_edges = sess.query(Edge).filter(
+ Edge.start.near(Point(1, 1), 1)
+ ).all()
+
+ assert edge_1 not in near_edges
+ assert edge_2 in near_edges
+
+ near_edges = sess.query(Edge).filter(
+ Edge.start.near(Point(0, 1), 1)
+ ).all()
+
+ assert edge_1 in near_edges and edge_2 in near_edges
+
+