From 80c68c0e22e2b45b3eaffcb7485d6a9f5eb02ba4 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 15 Oct 2011 14:31:02 -0400 Subject: [PATCH] - Reinstated "comparator_factory" argument to composite(), removed when 0.7 was released. [ticket:2248] --- CHANGES | 4 + lib/sqlalchemy/orm/descriptor_props.py | 6 +- test/orm/test_composites.py | 128 ++++++++++++++++++++++++- 3 files changed, 135 insertions(+), 3 deletions(-) diff --git a/CHANGES b/CHANGES index bcb433d786..cddd86566f 100644 --- a/CHANGES +++ b/CHANGES @@ -85,6 +85,10 @@ CHANGES deferred=True option failed due to missing import [ticket:2253] + - Reinstated "comparator_factory" argument to + composite(), removed when 0.7 was released. + [ticket:2248] + - Fixed bug in query.join() which would occur in a complex multiple-overlapping path scenario, where the same table could be joined to diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index cb31fadac0..594705a8af 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -79,6 +79,8 @@ class CompositeProperty(DescriptorProperty): self.active_history = kwargs.get('active_history', False) self.deferred = kwargs.get('deferred', False) self.group = kwargs.get('group', None) + self.comparator_factory = kwargs.pop('comparator_factory', + self.__class__.Comparator) util.set_creation_order(self) self._create_descriptor() @@ -257,11 +259,11 @@ class CompositeProperty(DescriptorProperty): ) def _comparator_factory(self, mapper): - return CompositeProperty.Comparator(self) + return self.comparator_factory(self) class Comparator(PropComparator): def __init__(self, prop, adapter=None): - self.prop = prop + self.prop = self.property = prop self.adapter = adapter def __clause_element__(self): diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index 0d3cc20d63..0c16e57a18 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -5,7 +5,7 @@ from sqlalchemy import MetaData, Integer, String, ForeignKey, func, \ util, select from test.lib.schema import Table, Column from sqlalchemy.orm import mapper, relationship, backref, \ - class_mapper, \ + class_mapper, CompositeProperty, \ validates, aliased from sqlalchemy.orm import attributes, \ composite, relationship, \ @@ -634,3 +634,129 @@ class ConfigurationTest(fixtures.MappedTest): deferred=True) }) self._test_roundtrip() + +class ComparatorTest(fixtures.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('edge', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('x1', Integer), + Column('y1', Integer), + Column('x2', Integer), + Column('y2', Integer), + ) + + @classmethod + def setup_mappers(cls): + class Point(cls.Comparable): + def __init__(self, x, y): + self.x = x + self.y = y + def __composite_values__(self): + return [self.x, self.y] + def __eq__(self, other): + return isinstance(other, Point) and \ + other.x == self.x and \ + other.y == self.y + def __ne__(self, other): + return not isinstance(other, Point) or \ + not self.__eq__(other) + + class Edge(cls.Comparable): + def __init__(self, start, end): + self.start = start + self.end = end + + def __eq__(self, other): + return isinstance(other, Edge) and \ + other.id == self.id + + def _fixture(self, custom): + edge, Edge, Point = (self.tables.edge, + self.classes.Edge, + self.classes.Point) + + if custom: + class CustomComparator(sa.orm.CompositeProperty.Comparator): + def near(self, other, d): + clauses = self.__clause_element__().clauses + diff_x = clauses[0] - other.x + diff_y = clauses[1] - other.y + return diff_x * diff_x + diff_y * diff_y <= d * d + + mapper(Edge, edge, properties={ + 'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1, + comparator_factory=CustomComparator), + 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2) + }) + else: + mapper(Edge, edge, properties={ + 'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1), + 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2) + }) + + def test_comparator_behavior_default(self): + self._fixture(False) + self._test_comparator_behavior() + + def test_comparator_behavior_custom(self): + self._fixture(True) + self._test_comparator_behavior() + + def _test_comparator_behavior(self): + Edge, Point = (self.classes.Edge, + self.classes.Point) + + sess = Session() + e1 = Edge(Point(3, 4), Point(5, 6)) + e2 = Edge(Point(14, 5), Point(2, 7)) + sess.add_all([e1, e2]) + sess.commit() + + assert sess.query(Edge).\ + filter(Edge.start==Point(3, 4)).one() is \ + e1 + + assert sess.query(Edge).\ + filter(Edge.start!=Point(3, 4)).first() is \ + e2 + + eq_( + sess.query(Edge).filter(Edge.start==None).all(), + [] + ) + + def test_default_comparator_factory(self): + self._fixture(False) + Edge = self.classes.Edge + start_prop = Edge.start.property + + assert start_prop.comparator_factory is CompositeProperty.Comparator + + def test_custom_comparator_factory(self): + self._fixture(True) + Edge, Point = (self.classes.Edge, + self.classes.Point) + + edge_1, edge_2 = Edge(Point(0, 0), Point(3, 5)), \ + Edge(Point(0, 1), Point(3, 5)) + + sess = Session() + sess.add_all([edge_1, edge_2]) + sess.commit() + + near_edges = sess.query(Edge).filter( + Edge.start.near(Point(1, 1), 1) + ).all() + + assert edge_1 not in near_edges + assert edge_2 in near_edges + + near_edges = sess.query(Edge).filter( + Edge.start.near(Point(0, 1), 1) + ).all() + + assert edge_1 in near_edges and edge_2 in near_edges + + -- 2.39.5