From 660e82ec505a05effc9faa2689bc0ddf9dde3c7d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 17 Jul 2011 16:44:37 -0400 Subject: [PATCH] - Fixed bug whereby if a mapped class redefined __hash__() or __eq__() to something non-standard, which is a supported use case as SQLA should never consult these, the methods would be consulted if the class was part of a "composite" (i.e. non-single-entity) result set. [ticket:2215] Also in 0.6.9. --- CHANGES | 9 ++++ lib/sqlalchemy/orm/query.py | 25 ++++----- test/orm/test_mapper.py | 103 ++++++++++++++++++++++++++---------- 3 files changed, 98 insertions(+), 39 deletions(-) diff --git a/CHANGES b/CHANGES index cea7db4696..daf1919ef7 100644 --- a/CHANGES +++ b/CHANGES @@ -32,6 +32,15 @@ CHANGES _with_invoke_all_eagers() which selects old/new behavior [ticket:2213] + - Fixed bug whereby if a mapped class + redefined __hash__() or __eq__() to something + non-standard, which is a supported use case + as SQLA should never consult these, + the methods would be consulted if the class + was part of a "composite" (i.e. non-single-entity) + result set. [ticket:2215] + Also in 0.6.9. + - Fixed subtle bug that caused SQL to blow up if: column_property() against subquery + joinedload + LIMIT + order by the column diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 5570e5a542..a3901868a8 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1878,22 +1878,18 @@ class Query(object): context.runid = _new_runid() - for ent in self._entities: - if isinstance(ent, _MapperEntity): - filtered = True - break - else: - filtered = False + filter_fns = [ent.filter_fn + for ent in self._entities] + filtered = id in filter_fns single_entity = filtered and len(self._entities) == 1 if filtered: if single_entity: - filter = lambda x: util.unique_list(x, id) + filter_fn = id else: - filter = util.unique_list - else: - filter = None + def filter_fn(row): + return tuple(fn(x) for x, fn in zip(row, filter_fns)) custom_rows = single_entity and \ self._entities[0].mapper.dispatch.append_result @@ -1926,8 +1922,8 @@ class Query(object): rows = [util.NamedTuple([proc(row, None) for proc in process], labels) for row in fetch] - if filter: - rows = filter(rows) + if filtered: + rows = util.unique_list(rows, filter_fn) if context.refresh_state and self._only_load_props \ and context.refresh_state in context.progress: @@ -2665,6 +2661,8 @@ class _MapperEntity(_QueryEntity): self.selectable = from_obj self.adapter = query._get_polymorphic_adapter(self, from_obj) + filter_fn = id + @property def type(self): return self.mapper.class_ @@ -2852,6 +2850,9 @@ class _ColumnEntity(_QueryEntity): def type(self): return self.column.type + def filter_fn(self, item): + return item + def adapt_to_selectable(self, query, sel): c = _ColumnEntity(query, sel.corresponding_column(self.column)) c._label_name = self._label_name diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index ea81847e5f..f749783f21 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -2637,6 +2637,19 @@ class RequirementsTest(fixtures.MappedTest): #self.assertRaises(sa.exc.ArgumentError, mapper, NoWeakrefSupport, t2) # end Py2K + class _ValueBase(object): + def __init__(self, value='abc', id=None): + self.id = id + self.value = value + def __nonzero__(self): + return False + def __hash__(self): + return hash(self.value) + def __eq__(self, other): + if isinstance(other, type(self)): + return self.value == other.value + return False + def test_comparison_overrides(self): """Simple tests to ensure users can supply comparison __methods__. @@ -2655,27 +2668,13 @@ class RequirementsTest(fixtures.MappedTest): self.tables.ht1) - # adding these methods directly to each class to avoid decoration - # by the testlib decorators. - class _Base(object): - def __init__(self, value='abc'): - self.value = value - def __nonzero__(self): - return False - def __hash__(self): - return hash(self.value) - def __eq__(self, other): - if isinstance(other, type(self)): - return self.value == other.value - return False - - class H1(_Base): + class H1(self._ValueBase): pass - class H2(_Base): + class H2(self._ValueBase): pass - class H3(_Base): + class H3(self._ValueBase): pass - class H6(_Base): + class H6(self._ValueBase): pass mapper(H1, ht1, properties={ @@ -2692,11 +2691,13 @@ class RequirementsTest(fixtures.MappedTest): mapper(H6, ht6) s = create_session() - for i in range(3): - h1 = H1() - s.add(h1) - - h1.h2s.append(H2()) + s.add_all([ + H1('abc'), + H1('def'), + ]) + h1 = H1('ghi') + s.add(h1) + h1.h2s.append(H2('abc')) h1.h3s.extend([H3(), H3()]) h1.h1s.append(H1()) @@ -2712,11 +2713,11 @@ class RequirementsTest(fixtures.MappedTest): h6.h1b = x = H1() assert x in s - h6.h1b.h2s.append(H2()) + h6.h1b.h2s.append(H2('def')) s.flush() - h1.h2s.extend([H2(), H2()]) + h1.h2s.extend([H2('abc'), H2('def')]) s.flush() h1s = s.query(H1).options(sa.orm.joinedload('h2s')).all() @@ -2726,10 +2727,10 @@ class RequirementsTest(fixtures.MappedTest): {'h2s': []}, {'h2s': []}, {'h2s': (H2, [{'value': 'abc'}, - {'value': 'abc'}, + {'value': 'def'}, {'value': 'abc'}])}, {'h2s': []}, - {'h2s': (H2, [{'value': 'abc'}])}) + {'h2s': (H2, [{'value': 'def'}])}) h1s = s.query(H1).options(sa.orm.joinedload('h3s')).all() @@ -2739,6 +2740,54 @@ class RequirementsTest(fixtures.MappedTest): sa.orm.joinedload_all('h3s.h1s')).all() eq_(len(h1s), 5) + + def test_composite_results(self): + ht2, ht1 = (self.tables.ht2, + self.tables.ht1) + + + class H1(self._ValueBase): + def __init__(self, value, id, h2s): + self.value = value + self.id = id + self.h2s = h2s + class H2(self._ValueBase): + def __init__(self, value, id): + self.value = value + self.id = id + + mapper(H1, ht1, properties={ + 'h2s': relationship(H2, backref='h1'), + }) + mapper(H2, ht2) + s = Session() + s.add_all([ + H1('abc', 1, h2s=[ + H2('abc', id=1), + H2('def', id=2), + H2('def', id=3), + ]), + H1('def', 2, h2s=[ + H2('abc', id=4), + H2('abc', id=5), + H2('def', id=6), + ]), + ]) + s.commit() + eq_( + [(h1.value, h1.id, h2.value, h2.id) + for h1, h2 in + s.query(H1, H2).join(H1.h2s).order_by(H1.id, H2.id)], + [ + ('abc', 1, 'abc', 1), + ('abc', 1, 'def', 2), + ('abc', 1, 'def', 3), + ('def', 2, 'abc', 4), + ('def', 2, 'abc', 5), + ('def', 2, 'def', 6), + ] + ) + def test_nonzero_len_recursion(self): ht1 = self.tables.ht1 -- 2.39.5