From 9eced72c035a9e0424ca8a77c9f657783a5a94dd Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 15 Jan 2008 17:59:27 +0000 Subject: [PATCH] finally, a really straightforward reduce() method which reduces cols to the minimal set for every test case I can come up with, and now replaces all the cruft in Mapper._compile_pks() as well as Join.__init_primary_key(). mappers can now handle aliased selects and figure out the correct PKs pretty well [ticket:933] --- CHANGES | 9 ++- lib/sqlalchemy/orm/mapper.py | 64 ++++------------- lib/sqlalchemy/orm/session.py | 2 +- lib/sqlalchemy/sql/expression.py | 53 +++----------- lib/sqlalchemy/sql/util.py | 36 ++++++---- test/orm/inheritance/query.py | 11 ++- test/orm/mapper.py | 2 +- test/sql/selectable.py | 119 +++++++++++++++++++++++++++++++ 8 files changed, 186 insertions(+), 110 deletions(-) diff --git a/CHANGES b/CHANGES index a940248185..b555603b7b 100644 --- a/CHANGES +++ b/CHANGES @@ -30,7 +30,14 @@ CHANGES - general improvements to the behavior of join() in conjunction with polymorphic mappers, i.e. joining from/to polymorphic mappers and properly applying - aliases + aliases. + + - fixed/improved behavior when a mapper determines the + natural "primary key" of a mapped join, it will more + effectively reduce columns which are equivalent via + foreign key relation. This affects how many arguments + need to be sent to query.get(), among other things. + [ticket:933] - fixed bug in polymorphic inheritance which made it difficult to set a working "order_by" on a polymorphic diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 61f5a65791..07075efd09 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -418,6 +418,7 @@ class Mapper(object): all_cols = util.Set(chain(*[c2 for c2 in [col.proxy_set for col in [c for c in self._columntoproperty]]])) pk_cols = util.Set([c for c in all_cols if c.primary_key]) + # identify primary key columns which are also mapped by this mapper. 
for t in util.Set(self.tables + [self.mapped_table]): self._all_tables.add(t) if t.primary_key and pk_cols.issuperset(t.primary_key): @@ -425,6 +426,7 @@ class Mapper(object): self._pks_by_table[t] = util.OrderedSet(t.primary_key).intersection(pk_cols) self._cols_by_table[t] = util.OrderedSet(t.c).intersection(all_cols) + # if explicit PK argument sent, add those columns to the primary key mappings if self.primary_key_argument: for k in self.primary_key_argument: if k.table not in self._pks_by_table: @@ -432,58 +434,22 @@ class Mapper(object): self._pks_by_table[k.table].add(k) if self.mapped_table not in self._pks_by_table or len(self._pks_by_table[self.mapped_table]) == 0: - raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) + raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description)) if self.inherits is not None and not self.concrete and not self.primary_key_argument: + # if inheriting, the "primary key" for this mapper is that of the inheriting (unless concrete or explicit) self.primary_key = self.inherits.primary_key self._get_clause = self.inherits._get_clause else: - # create the "primary_key" for this mapper. this will flatten "equivalent" primary key columns - # into one column, where "equivalent" means that one column references the other via foreign key, or - # multiple columns that all reference a common parent column. it will also resolve the column - # against the "mapped_table" of this mapper. - - # TODO !!! - #primary_key = sqlutil.reduce_columns((self.primary_key_argument or self._pks_by_table[self.mapped_table])) - - # TODO !!! 
remove all this - primary_key = expression.ColumnSet() - - for col in (self.primary_key_argument or self._pks_by_table[self.mapped_table]): - c = self.mapped_table.corresponding_column(col) - if c is None: - for cc in self._equivalent_columns[col]: - c = self.mapped_table.corresponding_column(cc) - if c is not None: - break - else: - raise exceptions.ArgumentError("Cant resolve column " + str(col)) - - # this step attempts to resolve the column to an equivalent which is not - # a foreign key elsewhere. this helps with joined table inheritance - # so that PKs are expressed in terms of the base table which is always - # present in the initial select - # TODO: this is a little hacky right now, the "tried" list is to prevent - # endless loops between cyclical FKs, try to make this cleaner/work better/etc., - # perhaps via topological sort (pick the leftmost item) - tried = util.Set() - while True: - if not len(c.foreign_keys) or c in tried: - break - for cc in c.foreign_keys: - cc = cc.column - c2 = self.mapped_table.corresponding_column(cc) - if c2 is not None: - c = c2 - tried.add(c) - break - else: - break - primary_key.add(c) + # determine primary key from argument or mapped_table pks - reduce to the minimal set of columns + if self.primary_key_argument: + primary_key = sqlutil.reduce_columns([self.mapped_table.corresponding_column(c) for c in self.primary_key_argument]) + else: + primary_key = sqlutil.reduce_columns(self._pks_by_table[self.mapped_table]) if len(primary_key) == 0: - raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) - + raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description)) + self.primary_key = primary_key self.__log("Identified primary key columns: " + str(primary_key)) @@ -730,15 +696,9 @@ class Mapper(object): if self.select_table is not self.mapped_table: # turn a straight join 
into an aliased selectable if isinstance(self.select_table, sql.Join): - if self.primary_key_argument: - primary_key_arg = self.primary_key_argument - else: - primary_key_arg = self.select_table.primary_key self.select_table = self.select_table.select(use_labels=True).alias() - else: - primary_key_arg = self.primary_key_argument - self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, _polymorphic_map=self.polymorphic_map, polymorphic_on=_corresponding_column_or_error(self.select_table, self.polymorphic_on), primary_key=primary_key_arg) + self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, _polymorphic_map=self.polymorphic_map, polymorphic_on=_corresponding_column_or_error(self.select_table, self.polymorphic_on), primary_key=self.primary_key_argument) adapter = sqlutil.ClauseAdapter(self.select_table, equivalents=self.__surrogate_mapper._equivalent_columns) if self.order_by: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index b817d29bc8..0d58a52161 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -572,7 +572,7 @@ class Session(object): This is equivalent to calling ``expunge()`` for all objects in this ``Session``. 
""" - + for instance in self: self._unattach(instance) self.uow = unitofwork.UnitOfWork(self) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index be870ee792..3ebc4960fa 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -31,7 +31,7 @@ from sqlalchemy import util, exceptions from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes -functions, schema = None, None +functions, schema, sql_util = None, None, None DefaultDialect, ClauseAdapter = None, None __all__ = [ @@ -2179,51 +2179,14 @@ class Join(FromClause): columns = list(self._flatten_exportable_columns()) - #global sql_util - #if not sql_util: - # from sqlalchemy.sql import util as sql_util - #self._primary_key = sql_util.reduce_columns([c for c in columns if c.primary_key], self.onclause) - - self.__init_primary_key(columns) + global sql_util + if not sql_util: + from sqlalchemy.sql import util as sql_util + self._primary_key = sql_util.reduce_columns([c for c in columns if c.primary_key], self.onclause) for co in columns: cp = self._proxy_column(co) - def __init_primary_key(self, columns): - # TODO !!! 
remove all this - global schema - if schema is None: - from sqlalchemy import schema - pkcol = util.Set([c for c in columns if c.primary_key]) - - equivs = {} - def add_equiv(a, b): - for x, y in ((a, b), (b, a)): - if x in equivs: - equivs[x].add(y) - else: - equivs[x] = util.Set([y]) - - def visit_binary(binary): - if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column): - add_equiv(binary.left, binary.right) - visitors.traverse(self.onclause, visit_binary=visit_binary) - - for col in pkcol: - for fk in col.foreign_keys: - if fk.column in pkcol: - add_equiv(col, fk.column) - - omit = util.Set() - for col in pkcol: - p = col - for c in equivs.get(col, util.Set()): - if p.references(c) or (c.primary_key and not p.primary_key): - omit.add(p) - p = c - - self._primary_key = ColumnSet(pkcol.difference(omit)) - def description(self): return "Join object on %s(%d) and %s(%d)" % (self.left.description, id(self.left), self.right.description, id(self.right)) description = property(description) @@ -2284,6 +2247,12 @@ class Join(FromClause): """Returns the column list of this Join with all equivalently-named, equated columns folded into one column, where 'equated' means they are equated to each other in the ON clause of this join. + + this method is used by select(fold_equivalents=True). + + The primary usage for this is when generating UNIONs so that + each selectable can have distinctly-named columns without the need + for use_labels=True. """ if self.__folded_equivalents is not None: diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 0989cb43e9..93998c9a91 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -52,30 +52,42 @@ def find_columns(clause): def reduce_columns(columns, *clauses): - raise NotImplementedError() + """given a list of columns, return a 'reduced' set based on natural equivalents. 
+ + the set is reduced to the smallest list of columns which have no natural + equivalent present in the list. A "natural equivalent" means that two columns + will ultimately represent the same value because they are related by a foreign key. + + \*clauses is an optional list of join clauses which will be traversed + to further identify columns that are "equivalent". - # TODO !!! - all_proxied_cols = util.Set(chain(*[c.proxy_set for c in columns])) + This function is primarily used to determine the most minimal "primary key" + from a selectable, by reducing the set of primary key columns present + in the selectable to just those that are not repeated. + + """ columns = util.Set(columns) - equivs = {} + omit = util.Set() for col in columns: for fk in col.foreign_keys: - if fk.column in all_proxied_cols: - for c in columns: - if col.references(c): - equivs[col] = c + for c in columns: + if c is col: + continue + if fk.column.shares_lineage(c): + omit.add(col) + break if clauses: def visit_binary(binary): - if binary.operator == operators.eq and binary.left in columns and binary.right in columns: - equivs[binary.left] = binary.right + cols = columns.difference(omit) + if binary.operator == operators.eq and binary.left in cols and binary.right in cols: + omit.add(binary.right) for clause in clauses: visitors.traverse(clause, visit_binary=visit_binary) - result = util.Set([c for c in columns if c not in equivs]) - return expression.ColumnSet(result) + return expression.ColumnSet(columns.difference(omit)) class ColumnsInClause(visitors.ClauseVisitor): """Given a selectable, visit clauses and determine if any columns diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/query.py index b3239d3b3a..b9f11faa7c 100644 --- a/test/orm/inheritance/query.py +++ b/test/orm/inheritance/query.py @@ -159,7 +159,16 @@ def make_test(select_type): all_employees = [e1, e2, b1, m1, e3] c1_employees = [e1, e2, b1, m1] c2_employees = [e3] - + + def test_get(self): + sess = 
create_session() + + # for all mappers, ensure the primary key has been calculated as just the "person_id" + # column + self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert")) + self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert")) + self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss")) + def test_filter_on_subclass(self): sess = create_session() self.assertEquals(sess.query(Engineer).all()[0], Engineer(name="dilbert")) diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 70cd81428d..a8f75a31b8 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -72,7 +72,7 @@ class MapperTest(MapperSuperTest): mapper(User, s) assert False except exceptions.ArgumentError, e: - assert str(e) == "Could not assemble any primary key columns for mapped table 'foo'" + assert "could not assemble any primary key columns for mapped table 'foo'" in str(e) def test_compileonsession(self): m = mapper(User, users) diff --git a/test/sql/selectable.py b/test/sql/selectable.py index a64697b81d..45bd7d823a 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -5,6 +5,7 @@ every selectable unit behaving nicely with others..""" import testenv; testenv.configure_for_tests() from sqlalchemy import * from testlib import * +from sqlalchemy.sql import util as sql_util metadata = MetaData() table = Table('table1', metadata, @@ -275,6 +276,124 @@ class PrimaryKeyTest(AssertMixin): assert str(j) == "a JOIN b ON a.id = b.id AND b.x = :b_x_1", str(j) assert list(j.primary_key) == [a.c.id, b.c.x] + def test_onclause_direction(self): + metadata = MetaData() + + employee = Table( 'Employee', metadata, + Column('name', String(100)), + Column('id', Integer, primary_key= True), + ) + + engineer = Table( 'Engineer', metadata, + Column('id', Integer, ForeignKey( 'Employee.id', ), primary_key=True), + ) + + self.assertEquals( + set(employee.join(engineer, employee.c.id==engineer.c.id).primary_key), + 
set([employee.c.id]) + ) + + self.assertEquals( + set(employee.join(engineer, engineer.c.id==employee.c.id).primary_key), + set([employee.c.id]) + ) + + +class ReduceTest(AssertMixin): + def test_reduce(self): + meta = MetaData() + t1 = Table('t1', meta, + Column('t1id', Integer, primary_key=True), + Column('t1data', String(30))) + t2 = Table('t2', meta, + Column('t2id', Integer, ForeignKey('t1.t1id'), primary_key=True), + Column('t2data', String(30))) + t3 = Table('t3', meta, + Column('t3id', Integer, ForeignKey('t2.t2id'), primary_key=True), + Column('t3data', String(30))) + + + self.assertEquals( + set(sql_util.reduce_columns([t1.c.t1id, t1.c.t1data, t2.c.t2id, t2.c.t2data, t3.c.t3id, t3.c.t3data])), + set([t1.c.t1id, t1.c.t1data, t2.c.t2data, t3.c.t3data]) + ) + + def test_reduce_aliased_join(self): + metadata = MetaData() + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('name', String(50)), + Column('type', String(30))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('engineer_name', String(50)), + Column('primary_language', String(50)), + ) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('manager_name', String(50)) + ) + + pjoin = people.outerjoin(engineers).outerjoin(managers).select(use_labels=True).alias('pjoin') + self.assertEquals( + set(sql_util.reduce_columns([pjoin.c.people_person_id, pjoin.c.engineers_person_id, pjoin.c.managers_person_id])), + set([pjoin.c.people_person_id]) + ) + + def test_reduce_aliased_union(self): + metadata = MetaData() + item_table = Table( + 'item', metadata, + Column('id', Integer, ForeignKey('base_item.id'), primary_key=True), + Column('dummy', Integer, default=0)) + + base_item_table = Table( + 
'base_item', metadata, + Column('id', Integer, primary_key=True), + Column('child_name', String(255), default=None)) + + from sqlalchemy.orm.util import polymorphic_union + + item_join = polymorphic_union( { + 'BaseItem':base_item_table.select(base_item_table.c.child_name=='BaseItem'), + 'Item':base_item_table.join(item_table), + }, None, 'item_join') + + self.assertEquals( + set(sql_util.reduce_columns([item_join.c.id, item_join.c.dummy, item_join.c.child_name])), + set([item_join.c.id, item_join.c.dummy, item_join.c.child_name]) + ) + + def test_reduce_aliased_union_2(self): + metadata = MetaData() + + page_table = Table('page', metadata, + Column('id', Integer, primary_key=True), + ) + magazine_page_table = Table('magazine_page', metadata, + Column('page_id', Integer, ForeignKey('page.id'), primary_key=True), + ) + classified_page_table = Table('classified_page', metadata, + Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True), + ) + + from sqlalchemy.orm.util import polymorphic_union + pjoin = polymorphic_union( + { + 'm': page_table.join(magazine_page_table), + 'c': page_table.join(magazine_page_table).join(classified_page_table), + }, None, 'page_join') + + self.assertEquals( + set(sql_util.reduce_columns([pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])), + set([pjoin.c.id]) + ) + + class DerivedTest(AssertMixin): def test_table(self): meta = MetaData() -- 2.47.3