]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Reinstated "comparator_factory" argument to
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 15 Oct 2011 18:31:02 +0000 (14:31 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 15 Oct 2011 18:31:02 +0000 (14:31 -0400)
composite(), removed when 0.7 was released.
[ticket:2248]

CHANGES
lib/sqlalchemy/orm/descriptor_props.py
test/orm/test_composites.py

diff --git a/CHANGES b/CHANGES
index bcb433d7862872c6ff28cde68f52a84058364d76..cddd86566f52d64e89de1f60d56037e3b5e2d971 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -85,6 +85,10 @@ CHANGES
     deferred=True option failed due to missing
     import [ticket:2253]
 
+  - Reinstated "comparator_factory" argument to 
+    composite(), removed when 0.7 was released.  
+    [ticket:2248]
+
   - Fixed bug in query.join() which would occur
     in a complex multiple-overlapping path scenario,
     where the same table could be joined to
index cb31fadac0f77cbd4b494267945ceb6f5ec68e24..594705a8af9d78c86dda51b628c60766e8108f1a 100644 (file)
@@ -79,6 +79,8 @@ class CompositeProperty(DescriptorProperty):
         self.active_history = kwargs.get('active_history', False)
         self.deferred = kwargs.get('deferred', False)
         self.group = kwargs.get('group', None)
+        self.comparator_factory = kwargs.pop('comparator_factory',
+                                            self.__class__.Comparator)
         util.set_creation_order(self)
         self._create_descriptor()
 
@@ -257,11 +259,11 @@ class CompositeProperty(DescriptorProperty):
             )
 
     def _comparator_factory(self, mapper):
-        return CompositeProperty.Comparator(self)
+        return self.comparator_factory(self)
 
     class Comparator(PropComparator):
         def __init__(self, prop, adapter=None):
-            self.prop = prop
+            self.prop = self.property = prop
             self.adapter = adapter
 
         def __clause_element__(self):
index 0d3cc20d6397758cca55170b4a8281519086ddb9..0c16e57a18977b9ee4a73dabad617ea159acced8 100644 (file)
@@ -5,7 +5,7 @@ from sqlalchemy import MetaData, Integer, String, ForeignKey, func, \
     util, select
 from test.lib.schema import Table, Column
 from sqlalchemy.orm import mapper, relationship, backref, \
-    class_mapper,  \
+    class_mapper, CompositeProperty, \
     validates, aliased
 from sqlalchemy.orm import attributes, \
     composite, relationship, \
@@ -634,3 +634,129 @@ class ConfigurationTest(fixtures.MappedTest):
                                             deferred=True)
         })
         self._test_roundtrip()
+
+class ComparatorTest(fixtures.MappedTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('edge', metadata,
+            Column('id', Integer, primary_key=True, 
+                                test_needs_autoincrement=True),
+            Column('x1', Integer),
+            Column('y1', Integer),
+            Column('x2', Integer),
+            Column('y2', Integer),
+        )
+
+    @classmethod
+    def setup_mappers(cls):
+        class Point(cls.Comparable):
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+            def __composite_values__(self):
+                return [self.x, self.y]
+            def __eq__(self, other):
+                return isinstance(other, Point) and \
+                        other.x == self.x and \
+                        other.y == self.y
+            def __ne__(self, other):
+                return not isinstance(other, Point) or \
+                    not self.__eq__(other)
+
+        class Edge(cls.Comparable):
+            def __init__(self, start, end):
+                self.start = start
+                self.end = end
+
+            def __eq__(self, other):
+                return isinstance(other, Edge) and \
+                       other.id == self.id
+
+    def _fixture(self, custom):
+        edge, Edge, Point = (self.tables.edge,
+                                self.classes.Edge,
+                                self.classes.Point)
+
+        if custom:
+            class CustomComparator(sa.orm.CompositeProperty.Comparator):
+                def near(self, other, d):
+                    clauses = self.__clause_element__().clauses
+                    diff_x = clauses[0] - other.x
+                    diff_y = clauses[1] - other.y
+                    return diff_x * diff_x + diff_y * diff_y <= d * d
+
+            mapper(Edge, edge, properties={
+                'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1, 
+                                          comparator_factory=CustomComparator),
+                'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2)
+            })
+        else:
+            mapper(Edge, edge, properties={
+                'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1),
+                'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2)
+            })
+
+    def test_comparator_behavior_default(self):
+        self._fixture(False)
+        self._test_comparator_behavior()
+
+    def test_comparator_behavior_custom(self):
+        self._fixture(True)
+        self._test_comparator_behavior()
+
+    def _test_comparator_behavior(self):
+        Edge, Point = (self.classes.Edge,
+                                self.classes.Point)
+
+        sess = Session()
+        e1 = Edge(Point(3, 4), Point(5, 6))
+        e2 = Edge(Point(14, 5), Point(2, 7))
+        sess.add_all([e1, e2])
+        sess.commit()
+
+        assert sess.query(Edge).\
+                    filter(Edge.start==Point(3, 4)).one() is \
+                    e1
+
+        assert sess.query(Edge).\
+                    filter(Edge.start!=Point(3, 4)).first() is \
+                    e2
+
+        eq_(
+            sess.query(Edge).filter(Edge.start==None).all(), 
+            []
+        )
+
+    def test_default_comparator_factory(self):
+        self._fixture(False)
+        Edge = self.classes.Edge
+        start_prop = Edge.start.property
+
+        assert start_prop.comparator_factory is CompositeProperty.Comparator
+
+    def test_custom_comparator_factory(self):
+        self._fixture(True)
+        Edge, Point = (self.classes.Edge,
+                                self.classes.Point)
+
+        edge_1, edge_2 = Edge(Point(0, 0), Point(3, 5)), \
+                        Edge(Point(0, 1), Point(3, 5))
+
+        sess = Session()
+        sess.add_all([edge_1, edge_2])
+        sess.commit()
+
+        near_edges = sess.query(Edge).filter(
+            Edge.start.near(Point(1, 1), 1)
+        ).all()
+
+        assert edge_1 not in near_edges
+        assert edge_2 in near_edges
+
+        near_edges = sess.query(Edge).filter(
+            Edge.start.near(Point(0, 1), 1)
+        ).all()
+
+        assert edge_1 in near_edges and edge_2 in near_edges
+
+