From: Mike Bayer Date: Tue, 15 Jan 2008 02:34:17 +0000 (+0000) Subject: - select_table mapper turns straight join into aliased select + custom PK, to allow X-Git-Tag: rel_0_4_3~106 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=4870a41d277ef8638dd06d23ba20a69acf073739;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - select_table mapper turns straight join into aliased select + custom PK, to allow joins onto select_table mappers - starting a generalized reduce_columns func --- diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index c733c68ad2..61f5a65791 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -443,6 +443,10 @@ class Mapper(object): # multiple columns that all reference a common parent column. it will also resolve the column # against the "mapped_table" of this mapper. + # TODO !!! + #primary_key = sqlutil.reduce_columns((self.primary_key_argument or self._pks_by_table[self.mapped_table])) + + # TODO !!! remove all this primary_key = expression.ColumnSet() for col in (self.primary_key_argument or self._pks_by_table[self.mapped_table]): @@ -724,7 +728,17 @@ class Mapper(object): """ if self.select_table is not self.mapped_table: - self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, _polymorphic_map=self.polymorphic_map, polymorphic_on=_corresponding_column_or_error(self.select_table, self.polymorphic_on), primary_key=self.primary_key_argument) + # turn a straight join into an aliased selectable + if isinstance(self.select_table, sql.Join): + if self.primary_key_argument: + primary_key_arg = self.primary_key_argument + else: + primary_key_arg = self.select_table.primary_key + self.select_table = self.select_table.select(use_labels=True).alias() + else: + primary_key_arg = self.primary_key_argument + + self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, _polymorphic_map=self.polymorphic_map, polymorphic_on=_corresponding_column_or_error(self.select_table, self.polymorphic_on), primary_key=primary_key_arg) adapter = sqlutil.ClauseAdapter(self.select_table, equivalents=self.__surrogate_mapper._equivalent_columns) if self.order_by: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 4eb555d4d4..be870ee792 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -2178,11 +2178,19 @@ class Join(FromClause): self._foreign_keys = util.Set() columns = list(self._flatten_exportable_columns()) + + #global sql_util + #if not sql_util: + # from sqlalchemy.sql import util as sql_util + #self._primary_key = sql_util.reduce_columns([c for c in columns if c.primary_key], self.onclause) + self.__init_primary_key(columns) + for co in columns: cp = self._proxy_column(co) def __init_primary_key(self, columns): + # TODO !!! remove all this global schema if schema is None: from sqlalchemy import schema diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index c2ac26557e..0989cb43e9 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,5 +1,6 @@ from sqlalchemy import util, schema, topological -from sqlalchemy.sql import expression, visitors +from sqlalchemy.sql import expression, visitors, operators +from itertools import chain """Utility functions that build upon SQL and Schema constructs.""" @@ -49,6 +50,33 @@ def find_columns(clause): visitors.traverse(clause, visit_column=visit_column) return cols + +def reduce_columns(columns, *clauses): + raise NotImplementedError() + + # TODO !!! + all_proxied_cols = util.Set(chain(*[c.proxy_set for c in columns])) + + columns = util.Set(columns) + + equivs = {} + for col in columns: + for fk in col.foreign_keys: + if fk.column in all_proxied_cols: + for c in columns: + if col.references(c): + equivs[col] = c + + if clauses: + def visit_binary(binary): + if binary.operator == operators.eq and binary.left in columns and binary.right in columns: + equivs[binary.left] = binary.right + for clause in clauses: + visitors.traverse(clause, visit_binary=visit_binary) + + result = util.Set([c for c in columns if c not in equivs]) + return expression.ColumnSet(result) + class ColumnsInClause(visitors.ClauseVisitor): """Given a selectable, visit clauses and determine if any columns from the clause are in the selectable. diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/query.py index 2e3d392556..b3239d3b3a 100644 --- a/test/orm/inheritance/query.py +++ b/test/orm/inheritance/query.py @@ -187,16 +187,13 @@ def make_test(select_type): self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) def test_join_to_polymorphic(self): - if select_type == 'Joins': - return - sess = create_session() self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2) self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2) def test_join_to_subclass(self): - if select_type in ('Joins', ''): + if select_type == '': return sess = create_session() @@ -214,8 +211,6 @@ def make_test(select_type): self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) def test_join_through_polymorphic(self): - if select_type == 'Joins': - return sess = create_session()