]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug whereby if a mapped class
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2011 20:44:07 +0000 (16:44 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2011 20:44:07 +0000 (16:44 -0400)
redefined __hash__() or __eq__() to something
non-standard, which is a supported use case
as SQLA should never consult these,
the methods would be consulted if the class
was part of a "composite" (i.e. non-single-entity)
result set.  [ticket:2215]

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/util.py
test/orm/test_mapper.py

diff --git a/CHANGES b/CHANGES
index a995c4270386ad2432bb2f208012fb591372b7fb..973cbdf7923785113cf6663fdeb2e4d163b6de35 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,6 +6,14 @@ CHANGES
 0.6.9
 =====
 - orm
+  - Fixed bug whereby if a mapped class
+    redefined __hash__() or __eq__() to something
+    non-standard, which is a supported use case
+    as SQLA should never consult these,
+    the methods would be consulted if the class
+    was part of a "composite" (i.e. non-single-entity)
+    result set.  [ticket:2215]
+
   - Fixed subtle bug that caused SQL to blow
     up if: column_property() against subquery +
     joinedload + LIMIT + order by the column
index f83787490d300ca2b4e5e0f9ddb99cc29829c9c6..956136941c4282aa580565bcea3ed26a0689cda7 100644 (file)
@@ -1790,16 +1790,17 @@ class Query(object):
 
         context.runid = _new_runid()
 
-        filtered = bool(list(self._mapper_entities))
+        filter_fns = [ent.filter_fn
+                    for ent in self._entities]
+        filtered = id in filter_fns
         single_entity = filtered and len(self._entities) == 1
 
         if filtered:
             if single_entity:
-                filter = lambda x: util.unique_list(x, util.IdentitySet)
+                filter_fn = id
             else:
-                filter = util.unique_list
-        else:
-            filter = None
+                def filter_fn(row):
+                    return tuple(fn(x) for x, fn in zip(row, filter_fns))
 
         custom_rows = single_entity and \
                         'append_result' in self._entities[0].extension
@@ -1832,8 +1833,8 @@ class Query(object):
                 rows = [util.NamedTuple([proc(row, None) for proc in process],
                                         labels) for row in fetch]
 
-            if filter:
-                rows = filter(rows)
+            if filtered:
+                rows = util.unique_list(rows, filter_fn)
 
             if context.refresh_state and self._only_load_props \
                         and context.refresh_state in context.progress:
@@ -2602,6 +2603,8 @@ class _MapperEntity(_QueryEntity):
             self.selectable = from_obj
             self.adapter = query._get_polymorphic_adapter(self, from_obj)
 
+    filter_fn = id
+
     @property
     def type(self):
         return self.mapper.class_
@@ -2783,6 +2786,9 @@ class _ColumnEntity(_QueryEntity):
     def type(self):
         return self.column.type
 
+    def filter_fn(self, item):
+        return item
+
     def adapt_to_selectable(self, query, sel):
         c = _ColumnEntity(query, sel.corresponding_column(self.column))
         c.entity_zero = self.entity_zero
index 6b6f14be0ec5092581d9dc3fd4f52b71259fed17..68cd185dd59d0f246dda9145f3930040ad4ec111 100644 (file)
@@ -1254,9 +1254,16 @@ column_dict = dict
 ordered_column_set = OrderedSet
 populate_column_dict = PopulateDict
 
-def unique_list(seq, compare_with=set):
-    seen = compare_with()
-    return [x for x in seq if x not in seen and not seen.add(x)]
+def unique_list(seq, hashfunc=None):
+    seen = {}
+    if not hashfunc:
+        return [x for x in seq 
+                if x not in seen 
+                and not seen.__setitem__(x, True)]
+    else:
+        return [x for x in seq 
+                if hashfunc(x) not in seen 
+                and not seen.__setitem__(hashfunc(x), True)]
 
 class UniqueAppender(object):
     """Appends items to a collection ensuring uniqueness.
index d3d8568e2500daae3f14ecd69cbc43e3d9b220e8..be3903b7dfa60fa2829a0bb9b513f95fc21e078d 100644 (file)
@@ -2581,6 +2581,19 @@ class RequirementsTest(_base.MappedTest):
         #self.assertRaises(sa.exc.ArgumentError, mapper, NoWeakrefSupport, t2)
     # end Py2K
 
+    class _ValueBase(object):
+        def __init__(self, value='abc', id=None):
+            self.id = id
+            self.value = value
+        def __nonzero__(self):
+            return False
+        def __hash__(self):
+            return hash(self.value)
+        def __eq__(self, other):
+            if isinstance(other, type(self)):
+                return self.value == other.value
+            return False
+
     @testing.resolve_artifact_names
     def test_comparison_overrides(self):
         """Simple tests to ensure users can supply comparison __methods__.
@@ -2592,27 +2605,13 @@ class RequirementsTest(_base.MappedTest):
         test run.
         """
 
-        # adding these methods directly to each class to avoid decoration
-        # by the testlib decorators.
-        class _Base(object):
-            def __init__(self, value='abc'):
-                self.value = value
-            def __nonzero__(self):
-                return False
-            def __hash__(self):
-                return hash(self.value)
-            def __eq__(self, other):
-                if isinstance(other, type(self)):
-                    return self.value == other.value
-                return False
-
-        class H1(_Base):
+        class H1(self._ValueBase):
             pass
-        class H2(_Base):
+        class H2(self._ValueBase):
             pass
-        class H3(_Base):
+        class H3(self._ValueBase):
             pass
-        class H6(_Base):
+        class H6(self._ValueBase):
             pass
 
         mapper(H1, ht1, properties={
@@ -2629,11 +2628,13 @@ class RequirementsTest(_base.MappedTest):
         mapper(H6, ht6)
 
         s = create_session()
-        for i in range(3):
-            h1 = H1()
-            s.add(h1)
-
-        h1.h2s.append(H2())
+        s.add_all([
+            H1('abc'),
+            H1('def'),
+        ])
+        h1 = H1('ghi')
+        s.add(h1)
+        h1.h2s.append(H2('abc'))
         h1.h3s.extend([H3(), H3()])
         h1.h1s.append(H1())
 
@@ -2649,11 +2650,11 @@ class RequirementsTest(_base.MappedTest):
         h6.h1b = x = H1()
         assert x in s
 
-        h6.h1b.h2s.append(H2())
+        h6.h1b.h2s.append(H2('def'))
 
         s.flush()
 
-        h1.h2s.extend([H2(), H2()])
+        h1.h2s.extend([H2('abc'), H2('def')])
         s.flush()
 
         h1s = s.query(H1).options(sa.orm.joinedload('h2s')).all()
@@ -2663,10 +2664,10 @@ class RequirementsTest(_base.MappedTest):
                                      {'h2s': []},
                                      {'h2s': []},
                                      {'h2s': (H2, [{'value': 'abc'},
-                                                   {'value': 'abc'},
+                                                   {'value': 'def'},
                                                    {'value': 'abc'}])},
                                      {'h2s': []},
-                                     {'h2s': (H2, [{'value': 'abc'}])})
+                                     {'h2s': (H2, [{'value': 'def'}])})
 
         h1s = s.query(H1).options(sa.orm.joinedload('h3s')).all()
 
@@ -2676,6 +2677,50 @@ class RequirementsTest(_base.MappedTest):
                                   sa.orm.joinedload_all('h3s.h1s')).all()
         eq_(len(h1s), 5)
 
+    @testing.resolve_artifact_names
+    def test_composite_results(self):
+        class H1(self._ValueBase):
+            def __init__(self, value, id, h2s):
+                self.value = value
+                self.id = id
+                self.h2s = h2s
+        class H2(self._ValueBase):
+            def __init__(self, value, id):
+                self.value = value
+                self.id = id
+
+        mapper(H1, ht1, properties={
+            'h2s': relationship(H2, backref='h1'),
+        })
+        mapper(H2, ht2)
+        s = Session()
+        s.add_all([
+            H1('abc', 1, h2s=[
+                H2('abc', id=1),
+                H2('def', id=2),
+                H2('def', id=3),
+            ]),
+            H1('def', 2, h2s=[
+                H2('abc', id=4),
+                H2('abc', id=5),
+                H2('def', id=6),
+            ]),
+        ])
+        s.commit()
+        eq_(
+            [(h1.value, h1.id, h2.value, h2.id) 
+            for h1, h2 in 
+            s.query(H1, H2).join(H1.h2s).order_by(H1.id, H2.id)],
+            [
+                ('abc', 1, 'abc', 1),
+                ('abc', 1, 'def', 2),
+                ('abc', 1, 'def', 3),
+                ('def', 2, 'abc', 4),
+                ('def', 2, 'abc', 5),
+                ('def', 2, 'def', 6),
+            ]
+        )
+
     @testing.resolve_artifact_names
     def test_nonzero_len_recursion(self):
         class H1(object):