From: Mike Bayer Date: Wed, 3 Dec 2008 17:28:36 +0000 (+0000) Subject: - Two fixes to help prevent out-of-band columns from X-Git-Tag: rel_0_5_0~146 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0410eae36b36dc8ea7e747c4b81c7ec9de5f2da4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Two fixes to help prevent out-of-band columns from being rendered in polymorphic_union inheritance scenarios (which then causes extra tables to be rendered in the FROM clause causing cartesian products): - improvements to "column adaption" for a->b->c inheritance situations to better locate columns that are related to one another via multiple levels of indirection, rather than rendering the non-adapted column. - the "polymorphic discriminator" column is only rendered for the actual mapper being queried against. The column won't be "pulled in" from a subclass or superclass mapper since it's not needed. --- diff --git a/CHANGES b/CHANGES index 562a5baa8f..f1245c7bfa 100644 --- a/CHANGES +++ b/CHANGES @@ -61,6 +61,23 @@ CHANGES - made Session.merge cascades not trigger autoflush. Fixes merged instances getting prematurely inserted with missing values. + + - Two fixes to help prevent out-of-band columns from + being rendered in polymorphic_union inheritance + scenarios (which then causes extra tables to be + rendered in the FROM clause causing cartesian + products): + - improvements to "column adaption" for + a->b->c inheritance situations to better + locate columns that are related to one + another via multiple levels of indirection, + rather than rendering the non-adapted + column. + - the "polymorphic discriminator" column is + only rendered for the actual mapper being + queried against. The column won't be + "pulled in" from a subclass or superclass + mapper since it's not needed. - sql - Fixed the import weirdness in sqlalchemy.sql diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index b48297a645..48c2f9e27f 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -612,7 +612,13 @@ class Mapper(object): # right set if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]: self._cols_by_table[col.table].add(col) - + + # if this ColumnProperty represents the "polymorphic discriminator" + # column, mark it. We'll need this when rendering columns + # in SELECT statements. + if not hasattr(prop, '_is_polymorphic_discriminator'): + prop._is_polymorphic_discriminator = (col is self.polymorphic_on or prop.columns[0] is self.polymorphic_on) + self.columns[key] = col for col in prop.columns: for col in col.proxy_set: @@ -860,20 +866,27 @@ class Mapper(object): else: return mappers, self._selectable_from_mappers(mappers) - @property - def _default_polymorphic_properties(self): - return util.unique_list( - chain(*[list(mapper.iterate_properties) for mapper in [self] + self._with_polymorphic_mappers]) - ) - def _iterate_polymorphic_properties(self, mappers=None): + """Return an iterator of MapperProperty objects which will render into a SELECT.""" + if mappers is None: - return iter(self._default_polymorphic_properties) + mappers = self._with_polymorphic_mappers + + if not mappers: + for c in self.iterate_properties: + yield c else: - return iter(util.unique_list( + # in the polymorphic case, filter out discriminator columns + # from other mappers, as these are sometimes dependent on that + # mapper's polymorphic selectable (which we don't want rendered) + for c in util.unique_list( chain(*[list(mapper.iterate_properties) for mapper in [self] + mappers]) - )) - + ): + if getattr(c, '_is_polymorphic_discriminator', False) and \ + (not self.polymorphic_on or c.columns[0] is not self.polymorphic_on): + continue + yield c + @property def properties(self): raise NotImplementedError("Public collection of MapperProperty objects is " diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d5f2417c27..1bd4dd857d 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -439,12 +439,12 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): self.exclude = exclude self.equivalents = equivalents or {} - def _corresponding_column(self, col, require_embedded): + def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET): newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded) - if not newcol and col in self.equivalents: + if not newcol and col in self.equivalents and col not in _seen: for equiv in self.equivalents[col]: - newcol = self.selectable.corresponding_column(equiv, require_embedded=require_embedded) + newcol = self._corresponding_column(equiv, require_embedded=require_embedded, _seen=_seen.union([col])) if newcol: return newcol return newcol diff --git a/test/orm/inheritance/concrete.py b/test/orm/inheritance/concrete.py index b8d799b604..e6277f3e91 100644 --- a/test/orm/inheritance/concrete.py +++ b/test/orm/inheritance/concrete.py @@ -7,12 +7,18 @@ from sqlalchemy.orm import attributes class ConcreteTest(ORMTest): def define_tables(self, metadata): - global managers_table, engineers_table, hackers_table, companies + global managers_table, engineers_table, hackers_table, companies, employees_table companies = Table('companies', metadata, Column('id', Integer, primary_key=True), Column('name', String(50))) + employees_table = Table('employees', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('company_id', Integer, ForeignKey('companies.id')) + ) + managers_table = Table('managers', metadata, Column('employee_id', Integer, primary_key=True), Column('name', String(50)), @@ -80,7 +86,7 @@ class ConcreteTest(ORMTest): session.expire(manager, ['manager_data']) self.assertEquals(manager.manager_data, "knows how to manage things") - def test_multi_level(self): + def test_multi_level_no_base(self): class Employee(object): def __init__(self, name): self.name = name @@ -157,6 +163,92 @@ class ConcreteTest(ORMTest): assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"]) assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"]) assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) + assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"]) + + def test_multi_level_with_base(self): + class Employee(object): + def __init__(self, name): + self.name = name + def __repr__(self): + return self.__class__.__name__ + " " + self.name + + class Manager(Employee): + def __init__(self, name, manager_data): + self.name = name + self.manager_data = manager_data + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.manager_data + + class Engineer(Employee): + def __init__(self, name, engineer_info): + self.name = name + self.engineer_info = engineer_info + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.engineer_info + + class Hacker(Engineer): + def __init__(self, name, nickname, engineer_info): + self.name = name + self.nickname = nickname + self.engineer_info = engineer_info + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " '" + \ + self.nickname + "' " + self.engineer_info + + pjoin = polymorphic_union({ + 'employee':employees_table, + 'manager': managers_table, + 'engineer': engineers_table, + 'hacker': hackers_table + }, 'type', 'pjoin') + + pjoin2 = polymorphic_union({ + 'engineer': engineers_table, + 'hacker': hackers_table + }, 'type', 'pjoin2') + + employee_mapper = mapper(Employee, employees_table, with_polymorphic=('*', pjoin), polymorphic_on=pjoin.c.type) + manager_mapper = mapper(Manager, managers_table, + inherits=employee_mapper, concrete=True, + polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, engineers_table, + with_polymorphic=('*', pjoin2), + polymorphic_on=pjoin2.c.type, + inherits=employee_mapper, concrete=True, + polymorphic_identity='engineer') + hacker_mapper = mapper(Hacker, hackers_table, + inherits=engineer_mapper, + concrete=True, polymorphic_identity='hacker') + + session = create_session() + tom = Manager('Tom', 'knows how to manage things') + jerry = Engineer('Jerry', 'knows how to program') + hacker = Hacker('Kurt', 'Badass', 'knows how to hack') + session.add_all((tom, jerry, hacker)) + session.flush() + + # ensure "readonly" on save logic didn't pollute the expired_attributes + # collection + assert 'nickname' not in attributes.instance_state(jerry).expired_attributes + assert 'name' not in attributes.instance_state(jerry).expired_attributes + assert 'name' not in attributes.instance_state(hacker).expired_attributes + assert 'nickname' not in attributes.instance_state(hacker).expired_attributes + def go(): + self.assertEquals(jerry.name, "Jerry") + self.assertEquals(hacker.nickname, "Badass") + self.assert_sql_count(testing.db, go, 0) + + session.clear() + + # check that we aren't getting a cartesian product in the raw SQL. + # this requires that Engineer's polymorphic discriminator is not rendered + # in the statement which is only against Employee's "pjoin" + assert len(testing.db.execute(session.query(Employee).with_labels().statement).fetchall()) == 3 + + assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"]) + assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"]) + assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) + assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"]) def test_relation(self): class Employee(object): diff --git a/test/sql/generative.py b/test/sql/generative.py index f6b849e8a3..4edf334f66 100644 --- a/test/sql/generative.py +++ b/test/sql/generative.py @@ -458,6 +458,32 @@ class ClauseAdapterTest(TestBase, AssertsCompiledSQL): assert str(e) == "a_1.id = a.xxx_id" + def test_recursive_equivalents(self): + m = MetaData() + a = Table('a', m, Column('x', Integer), Column('y', Integer)) + b = Table('b', m, Column('x', Integer), Column('y', Integer)) + c = Table('c', m, Column('x', Integer), Column('y', Integer)) + + # force a recursion overflow, by linking a.c.x<->c.c.x, and + # asking for a nonexistent col. corresponding_column should prevent + # endless depth. + adapt = sql_util.ClauseAdapter( b, equivalents= {a.c.x: set([ c.c.x]), c.c.x:set([a.c.x])}) + assert adapt._corresponding_column(a.c.x, False) is None + + def test_multilevel_equivalents(self): + m = MetaData() + a = Table('a', m, Column('x', Integer), Column('y', Integer)) + b = Table('b', m, Column('x', Integer), Column('y', Integer)) + c = Table('c', m, Column('x', Integer), Column('y', Integer)) + + alias = select([a]).select_from(a.join(b, a.c.x==b.c.x)).alias() + + # two levels of indirection from c.x->b.x->a.x, requires recursive + # corresponding_column call + adapt = sql_util.ClauseAdapter(alias, equivalents= {b.c.x: set([ a.c.x]), c.c.x:set([b.c.x])}) + assert adapt._corresponding_column(a.c.x, False) is alias.c.x + assert adapt._corresponding_column(c.c.x, False) is alias.c.x + def test_join_to_alias(self): metadata = MetaData() a = Table('a', metadata,