]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug whereby if a mapped class
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2011 20:44:37 +0000 (16:44 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2011 20:44:37 +0000 (16:44 -0400)
redefined __hash__() or __eq__() to something
non-standard, which is a supported use case
as SQLA should never consult these,
the methods would be consulted if the class
was part of a "composite" (i.e. non-single-entity)
result set.  [ticket:2215]
Also in 0.6.9.

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/test_mapper.py

diff --git a/CHANGES b/CHANGES
index cea7db4696501f3081f5f69d6b5ff6a96879d8d3..daf1919ef7f0ee780d6807b90e4487d801ebb01f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -32,6 +32,15 @@ CHANGES
     _with_invoke_all_eagers()
     which selects old/new behavior [ticket:2213]
 
+  - Fixed bug whereby if a mapped class
+    redefined __hash__() or __eq__() to something
+    non-standard, which is a supported use case
+    as SQLA should never consult these,
+    the methods would be consulted if the class
+    was part of a "composite" (i.e. non-single-entity)
+    result set.  [ticket:2215]
+    Also in 0.6.9.
+
   - Fixed subtle bug that caused SQL to blow
     up if: column_property() against subquery +
     joinedload + LIMIT + order by the column
index 5570e5a542827da1e052d42519c235c6954f6901..a3901868a83841f1da44741c7a1971844fdd6e8e 100644 (file)
@@ -1878,22 +1878,18 @@ class Query(object):
 
         context.runid = _new_runid()
 
-        for ent in self._entities:
-            if isinstance(ent, _MapperEntity):
-                filtered = True
-                break
-        else:
-            filtered = False
+        filter_fns = [ent.filter_fn
+                    for ent in self._entities]
+        filtered = id in filter_fns
 
         single_entity = filtered and len(self._entities) == 1
 
         if filtered:
             if single_entity:
-                filter = lambda x: util.unique_list(x, id)
+                filter_fn = id
             else:
-                filter = util.unique_list
-        else:
-            filter = None
+                def filter_fn(row):
+                    return tuple(fn(x) for x, fn in zip(row, filter_fns))
 
         custom_rows = single_entity and \
                         self._entities[0].mapper.dispatch.append_result
@@ -1926,8 +1922,8 @@ class Query(object):
                 rows = [util.NamedTuple([proc(row, None) for proc in process],
                                         labels) for row in fetch]
 
-            if filter:
-                rows = filter(rows)
+            if filtered:
+                rows = util.unique_list(rows, filter_fn)
 
             if context.refresh_state and self._only_load_props \
                         and context.refresh_state in context.progress:
@@ -2665,6 +2661,8 @@ class _MapperEntity(_QueryEntity):
             self.selectable = from_obj
             self.adapter = query._get_polymorphic_adapter(self, from_obj)
 
+    filter_fn = id
+
     @property
     def type(self):
         return self.mapper.class_
@@ -2852,6 +2850,9 @@ class _ColumnEntity(_QueryEntity):
     def type(self):
         return self.column.type
 
+    def filter_fn(self, item):
+        return item
+
     def adapt_to_selectable(self, query, sel):
         c = _ColumnEntity(query, sel.corresponding_column(self.column))
         c._label_name = self._label_name 
index ea81847e5fd7fbb7b4e3f66a1af33b10d75c9e91..f749783f21aef4fcd218855b4e4bfc608aa5c51e 100644 (file)
@@ -2637,6 +2637,19 @@ class RequirementsTest(fixtures.MappedTest):
         #self.assertRaises(sa.exc.ArgumentError, mapper, NoWeakrefSupport, t2)
     # end Py2K
 
+    class _ValueBase(object):
+        def __init__(self, value='abc', id=None):
+            self.id = id
+            self.value = value
+        def __nonzero__(self):
+            return False
+        def __hash__(self):
+            return hash(self.value)
+        def __eq__(self, other):
+            if isinstance(other, type(self)):
+                return self.value == other.value
+            return False
+
     def test_comparison_overrides(self):
         """Simple tests to ensure users can supply comparison __methods__.
 
@@ -2655,27 +2668,13 @@ class RequirementsTest(fixtures.MappedTest):
                                 self.tables.ht1)
 
 
-        # adding these methods directly to each class to avoid decoration
-        # by the testlib decorators.
-        class _Base(object):
-            def __init__(self, value='abc'):
-                self.value = value
-            def __nonzero__(self):
-                return False
-            def __hash__(self):
-                return hash(self.value)
-            def __eq__(self, other):
-                if isinstance(other, type(self)):
-                    return self.value == other.value
-                return False
-
-        class H1(_Base):
+        class H1(self._ValueBase):
             pass
-        class H2(_Base):
+        class H2(self._ValueBase):
             pass
-        class H3(_Base):
+        class H3(self._ValueBase):
             pass
-        class H6(_Base):
+        class H6(self._ValueBase):
             pass
 
         mapper(H1, ht1, properties={
@@ -2692,11 +2691,13 @@ class RequirementsTest(fixtures.MappedTest):
         mapper(H6, ht6)
 
         s = create_session()
-        for i in range(3):
-            h1 = H1()
-            s.add(h1)
-
-        h1.h2s.append(H2())
+        s.add_all([
+            H1('abc'),
+            H1('def'),
+        ])
+        h1 = H1('ghi')
+        s.add(h1)
+        h1.h2s.append(H2('abc'))
         h1.h3s.extend([H3(), H3()])
         h1.h1s.append(H1())
 
@@ -2712,11 +2713,11 @@ class RequirementsTest(fixtures.MappedTest):
         h6.h1b = x = H1()
         assert x in s
 
-        h6.h1b.h2s.append(H2())
+        h6.h1b.h2s.append(H2('def'))
 
         s.flush()
 
-        h1.h2s.extend([H2(), H2()])
+        h1.h2s.extend([H2('abc'), H2('def')])
         s.flush()
 
         h1s = s.query(H1).options(sa.orm.joinedload('h2s')).all()
@@ -2726,10 +2727,10 @@ class RequirementsTest(fixtures.MappedTest):
                                      {'h2s': []},
                                      {'h2s': []},
                                      {'h2s': (H2, [{'value': 'abc'},
-                                                   {'value': 'abc'},
+                                                   {'value': 'def'},
                                                    {'value': 'abc'}])},
                                      {'h2s': []},
-                                     {'h2s': (H2, [{'value': 'abc'}])})
+                                     {'h2s': (H2, [{'value': 'def'}])})
 
         h1s = s.query(H1).options(sa.orm.joinedload('h3s')).all()
 
@@ -2739,6 +2740,54 @@ class RequirementsTest(fixtures.MappedTest):
                                   sa.orm.joinedload_all('h3s.h1s')).all()
         eq_(len(h1s), 5)
 
+
+    def test_composite_results(self):
+        ht2, ht1 = (self.tables.ht2,
+                        self.tables.ht1)
+
+
+        class H1(self._ValueBase):
+            def __init__(self, value, id, h2s):
+                self.value = value
+                self.id = id
+                self.h2s = h2s
+        class H2(self._ValueBase):
+            def __init__(self, value, id):
+                self.value = value
+                self.id = id
+
+        mapper(H1, ht1, properties={
+            'h2s': relationship(H2, backref='h1'),
+        })
+        mapper(H2, ht2)
+        s = Session()
+        s.add_all([
+            H1('abc', 1, h2s=[
+                H2('abc', id=1),
+                H2('def', id=2),
+                H2('def', id=3),
+            ]),
+            H1('def', 2, h2s=[
+                H2('abc', id=4),
+                H2('abc', id=5),
+                H2('def', id=6),
+            ]),
+        ])
+        s.commit()
+        eq_(
+            [(h1.value, h1.id, h2.value, h2.id) 
+            for h1, h2 in 
+            s.query(H1, H2).join(H1.h2s).order_by(H1.id, H2.id)],
+            [
+                ('abc', 1, 'abc', 1),
+                ('abc', 1, 'def', 2),
+                ('abc', 1, 'def', 3),
+                ('def', 2, 'abc', 4),
+                ('def', 2, 'abc', 5),
+                ('def', 2, 'def', 6),
+            ]
+        )
+
     def test_nonzero_len_recursion(self):
         ht1 = self.tables.ht1