From dfe81ee73db80a3ec91732946c3054370f6454ee Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 17 Jul 2011 16:44:07 -0400 Subject: [PATCH] - Fixed bug whereby if a mapped class redefined __hash__() or __eq__() to something non-standard, which is a supported use case as SQLA should never consult these, the methods would be consulted if the class was part of a "composite" (i.e. non-single-entity) result set. [ticket:2215] --- CHANGES | 8 +++ lib/sqlalchemy/orm/query.py | 20 +++++--- lib/sqlalchemy/util.py | 13 +++-- test/orm/test_mapper.py | 99 +++++++++++++++++++++++++++---------- 4 files changed, 103 insertions(+), 37 deletions(-) diff --git a/CHANGES b/CHANGES index a995c42703..973cbdf792 100644 --- a/CHANGES +++ b/CHANGES @@ -6,6 +6,14 @@ CHANGES 0.6.9 ===== - orm + - Fixed bug whereby if a mapped class + redefined __hash__() or __eq__() to something + non-standard, which is a supported use case + as SQLA should never consult these, + the methods would be consulted if the class + was part of a "composite" (i.e. non-single-entity) + result set. [ticket:2215] + - Fixed subtle bug that caused SQL to blow up if: column_property() against subquery + joinedload + LIMIT + order by the column diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index f83787490d..956136941c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1790,16 +1790,17 @@ class Query(object): context.runid = _new_runid() - filtered = bool(list(self._mapper_entities)) + filter_fns = [ent.filter_fn + for ent in self._entities] + filtered = id in filter_fns single_entity = filtered and len(self._entities) == 1 if filtered: if single_entity: - filter = lambda x: util.unique_list(x, util.IdentitySet) + filter_fn = id else: - filter = util.unique_list - else: - filter = None + def filter_fn(row): + return tuple(fn(x) for x, fn in zip(row, filter_fns)) custom_rows = single_entity and \ 'append_result' in self._entities[0].extension @@ -1832,8 +1833,8 @@ class Query(object): rows = [util.NamedTuple([proc(row, None) for proc in process], labels) for row in fetch] - if filter: - rows = filter(rows) + if filtered: + rows = util.unique_list(rows, filter_fn) if context.refresh_state and self._only_load_props \ and context.refresh_state in context.progress: @@ -2602,6 +2603,8 @@ class _MapperEntity(_QueryEntity): self.selectable = from_obj self.adapter = query._get_polymorphic_adapter(self, from_obj) + filter_fn = id + @property def type(self): return self.mapper.class_ @@ -2783,6 +2786,9 @@ class _ColumnEntity(_QueryEntity): def type(self): return self.column.type + def filter_fn(self, item): + return item + def adapt_to_selectable(self, query, sel): c = _ColumnEntity(query, sel.corresponding_column(self.column)) c.entity_zero = self.entity_zero diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 6b6f14be0e..68cd185dd5 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -1254,9 +1254,16 @@ column_dict = dict ordered_column_set = OrderedSet populate_column_dict = PopulateDict -def unique_list(seq, compare_with=set): - seen = compare_with() - return [x for x in seq if x not in seen and not seen.add(x)] +def unique_list(seq, hashfunc=None): + seen = {} + if not hashfunc: + return [x for x in seq + if x not in seen + and not seen.__setitem__(x, True)] + else: + return [x for x in seq + if hashfunc(x) not in seen + and not seen.__setitem__(hashfunc(x), True)] class UniqueAppender(object): """Appends items to a collection ensuring uniqueness. diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index d3d8568e25..be3903b7df 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -2581,6 +2581,19 @@ class RequirementsTest(_base.MappedTest): #self.assertRaises(sa.exc.ArgumentError, mapper, NoWeakrefSupport, t2) # end Py2K + class _ValueBase(object): + def __init__(self, value='abc', id=None): + self.id = id + self.value = value + def __nonzero__(self): + return False + def __hash__(self): + return hash(self.value) + def __eq__(self, other): + if isinstance(other, type(self)): + return self.value == other.value + return False + @testing.resolve_artifact_names def test_comparison_overrides(self): """Simple tests to ensure users can supply comparison __methods__. @@ -2592,27 +2605,13 @@ class RequirementsTest(_base.MappedTest): test run. """ - # adding these methods directly to each class to avoid decoration - # by the testlib decorators. - class _Base(object): - def __init__(self, value='abc'): - self.value = value - def __nonzero__(self): - return False - def __hash__(self): - return hash(self.value) - def __eq__(self, other): - if isinstance(other, type(self)): - return self.value == other.value - return False - - class H1(_Base): + class H1(self._ValueBase): pass - class H2(_Base): + class H2(self._ValueBase): pass - class H3(_Base): + class H3(self._ValueBase): pass - class H6(_Base): + class H6(self._ValueBase): pass mapper(H1, ht1, properties={ @@ -2629,11 +2628,13 @@ class RequirementsTest(_base.MappedTest): mapper(H6, ht6) s = create_session() - for i in range(3): - h1 = H1() - s.add(h1) - - h1.h2s.append(H2()) + s.add_all([ + H1('abc'), + H1('def'), + ]) + h1 = H1('ghi') + s.add(h1) + h1.h2s.append(H2('abc')) h1.h3s.extend([H3(), H3()]) h1.h1s.append(H1()) @@ -2649,11 +2650,11 @@ class RequirementsTest(_base.MappedTest): h6.h1b = x = H1() assert x in s - h6.h1b.h2s.append(H2()) + h6.h1b.h2s.append(H2('def')) s.flush() - h1.h2s.extend([H2(), H2()]) + h1.h2s.extend([H2('abc'), H2('def')]) s.flush() h1s = s.query(H1).options(sa.orm.joinedload('h2s')).all() @@ -2663,10 +2664,10 @@ class RequirementsTest(_base.MappedTest): {'h2s': []}, {'h2s': []}, {'h2s': (H2, [{'value': 'abc'}, - {'value': 'abc'}, + {'value': 'def'}, {'value': 'abc'}])}, {'h2s': []}, - {'h2s': (H2, [{'value': 'abc'}])}) + {'h2s': (H2, [{'value': 'def'}])}) h1s = s.query(H1).options(sa.orm.joinedload('h3s')).all() @@ -2676,6 +2677,50 @@ class RequirementsTest(_base.MappedTest): sa.orm.joinedload_all('h3s.h1s')).all() eq_(len(h1s), 5) + @testing.resolve_artifact_names + def test_composite_results(self): + class H1(self._ValueBase): + def __init__(self, value, id, h2s): + self.value = value + self.id = id + self.h2s = h2s + class H2(self._ValueBase): + def __init__(self, value, id): + self.value = value + self.id = id + + mapper(H1, ht1, properties={ + 'h2s': relationship(H2, backref='h1'), + }) + mapper(H2, ht2) + s = Session() + s.add_all([ + H1('abc', 1, h2s=[ + H2('abc', id=1), + H2('def', id=2), + H2('def', id=3), + ]), + H1('def', 2, h2s=[ + H2('abc', id=4), + H2('abc', id=5), + H2('def', id=6), + ]), + ]) + s.commit() + eq_( + [(h1.value, h1.id, h2.value, h2.id) + for h1, h2 in + s.query(H1, H2).join(H1.h2s).order_by(H1.id, H2.id)], + [ + ('abc', 1, 'abc', 1), + ('abc', 1, 'def', 2), + ('abc', 1, 'def', 3), + ('def', 2, 'abc', 4), + ('def', 2, 'abc', 5), + ('def', 2, 'def', 6), + ] + ) + @testing.resolve_artifact_names def test_nonzero_len_recursion(self): class H1(object): -- 2.47.3