From 41df778985ce5b99935ff4b1ffa0c7249a03a83a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 10 Dec 2007 04:31:17 +0000 Subject: [PATCH] - more query tests - trying to refine some of the adaptation stuff - query.from_statement() wont allow further generative criterion - added a warning to columncollection when selectable is formed with conflicting columns (only in the col export phase) - some method rearrangement on schema/columncollection.... - property conflicting relation warning doesnt raise for concrete --- CHANGES | 5 ++- lib/sqlalchemy/orm/properties.py | 7 ++-- lib/sqlalchemy/orm/query.py | 62 ++++++++++++++++---------------- lib/sqlalchemy/schema.py | 21 ++++++----- lib/sqlalchemy/sql/expression.py | 40 +++++++++++++++++---- test/orm/mapper.py | 3 +- test/orm/query.py | 37 +++++++++++++++++-- test/sql/selectable.py | 5 ++- 8 files changed, 124 insertions(+), 56 deletions(-) diff --git a/CHANGES b/CHANGES index e0df814694..6154bc3d18 100644 --- a/CHANGES +++ b/CHANGES @@ -36,7 +36,10 @@ CHANGES of the underlying type. Ideal for using with Unicode or Pickletype. TypeDecorator should now be the primary way to augment the behavior of any existing type including other TypeDecorator subclasses such as PickleType. - + + - selectables (and others) will issue a warning when two columns in + their exported columns collection conflict based on name. + - tables with schemas can still be used in sqlite, firebird, schema name just gets dropped [ticket:890] diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index f0cae49d80..9394e9aead 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -399,9 +399,10 @@ class PropertyLoader(StrategizedProperty): # ensure the "select_mapper", if different from the regular target mapper, is compiled. self.mapper.get_select_mapper()._check_compile() - for inheriting in self.parent.iterate_to_root(): - if inheriting is not self.parent and inheriting.get_property(self.key, raiseerr=False): - warnings.warn(RuntimeWarning("Warning: relation '%s' on mapper '%s' supercedes the same relation on inherited mapper '%s'; this can cause dependency issues during flush" % (self.key, self.parent, inheriting))) + if not self.parent.concrete: + for inheriting in self.parent.iterate_to_root(): + if inheriting is not self.parent and inheriting.get_property(self.key, raiseerr=False): + warnings.warn(RuntimeWarning("Warning: relation '%s' on mapper '%s' supercedes the same relation on inherited mapper '%s'; this can cause dependency issues during flush" % (self.key, self.parent, inheriting))) if self.association is not None: if isinstance(self.association, type): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index dbc62a47b9..0dbfdc611e 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -49,6 +49,7 @@ class Query(object): self._autoflush = True self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]])) self._attributes = {} + self.__joinable_tables = {} self._current_path = () self._primary_adapter=None self._only_load_props = None @@ -66,7 +67,13 @@ class Query(object): q._statement = q._aliases = q._criterion = None q._order_by = q._group_by = q._distinct = False return q - + + def _no_statement(self, meth): + q = self._clone() + if q._statement: + raise exceptions.InvalidRequestError("Query.%s() being called on a Query with an existing full statement - can't apply criterion." % meth) + return q + def _clone(self): q = Query.__new__(Query) q.__dict__ = self.__dict__.copy() @@ -322,10 +329,10 @@ class Query(object): if self._aliases is not None: criterion = self._aliases.adapt_clause(criterion) - elif self._from_obj is not self.table: + elif self.table not in self._get_joinable_tables(): criterion = sql_util.ClauseAdapter(self._from_obj).traverse(criterion) - q = self._clone() + q = self._no_statement("filter") if q._criterion is not None: q._criterion = q._criterion & criterion else: @@ -341,12 +348,14 @@ class Query(object): return self.filter(sql.and_(*clauses)) def _get_joinable_tables(self): - currenttables = [self._from_obj] - def visit_join(join): - currenttables.append(join.left) - currenttables.append(join.right) - visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False}) - return currenttables + if self._from_obj not in self.__joinable_tables: + currenttables = [self._from_obj] + def visit_join(join): + currenttables.append(join.left) + currenttables.append(join.right) + visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False}) + self.__joinable_tables = {self._from_obj : currenttables} + return self.__joinable_tables[self._from_obj] def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True): if start is None: @@ -416,7 +425,7 @@ class Query(object): """ if self._column_aggregate is not None: raise exceptions.InvalidRequestError("Query already contains an aggregate column or function") - q = self._clone() + q = self._no_statement("aggregate") q._column_aggregate = (col, func) return q @@ -484,7 +493,7 @@ class Query(object): def order_by(self, criterion): """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``""" - q = self._clone() + q = self._no_statement("order_by") if q._order_by is False: q._order_by = util.to_list(criterion) else: @@ -494,7 +503,7 @@ class Query(object): def group_by(self, criterion): """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``""" - q = self._clone() + q = self._no_statement("group_by") if q._group_by is False: q._group_by = util.to_list(criterion) else: @@ -514,7 +523,7 @@ class Query(object): if self._aliases is not None: criterion = self._aliases.adapt_clause(criterion) - q = self._clone() + q = self._no_statement("having") if q._having is not None: q._having = q._having & criterion else: @@ -543,7 +552,7 @@ class Query(object): def _join(self, prop, id, outerjoin, aliased, from_joinpoint): (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased) - q = self._clone() + q = self._no_statement("join") q._from_obj = clause q._joinpoint = mapper q._aliases = aliases @@ -568,7 +577,7 @@ class Query(object): the root. """ - q = self._clone() + q = self._no_statement("reset_joinpoint") q._joinpoint = q.mapper q._aliases = None return q @@ -638,7 +647,7 @@ class Query(object): ``Query``. """ - new = self._clone() + new = self._no_statement("distinct") new._distinct = True return new @@ -828,7 +837,7 @@ class Query(object): if lockmode is not None: q = q.with_lockmode(lockmode) q = q._select_context_options(populate_existing=refresh_instance is not None, version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance) - q = q.order_by(None) + q._order_by = None # call using all() to avoid LIMIT compilation complexity return q.all()[0] except IndexError: @@ -891,23 +900,14 @@ class Query(object): whereclause = self._criterion from_obj = self._from_obj - currenttables = self._get_joinable_tables() - adapt_criterion = self.table not in currenttables + adapt_criterion = self.table not in self._get_joinable_tables() - if whereclause is not None and (self.mapper is not self.select_mapper): - # adapt the given WHERECLAUSE to adjust instances of this query's mapped - # table to be that of our select_table, - # which may be the "polymorphic" selectable used by our mapper. + if not adapt_criterion and whereclause is not None and (self.mapper is not self.select_mapper): whereclause = sql_util.ClauseAdapter(from_obj).traverse(whereclause, stop_on=util.Set([from_obj])) - # if extra entities, adapt the criterion to those as well - for m in self._entities: - if isinstance(m, type): - m = mapper.class_mapper(m) - if isinstance(m, mapper.Mapper): - sql_util.ClauseAdapter(m.select_table).traverse(whereclause, stop_on=util.Set([m.select_table])) + # TODO: mappers added via add_entity(), adapt their queries also, + # if those mappers are polymorphic - order_by = self._order_by if order_by is False: order_by = self.mapper.order_by @@ -969,7 +969,7 @@ class Query(object): if adapt_criterion: context.primary_columns = [from_obj.corresponding_column(c, raiseerr=False) or c for c in context.primary_columns] cf = [from_obj.corresponding_column(c, raiseerr=False) or c for c in cf] - + s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clause, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args()) s3 = s2.alias() diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 8179810033..15b35b96a3 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -493,6 +493,7 @@ class Column(SchemaItem, expression._ColumnClause): [repr(self.name)] + [repr(self.type)] + [repr(x) for x in self.foreign_keys if x is not None] + [repr(x) for x in self.constraints] + + [(self.table and "table=<%s>" % self.table.description or "")] + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]) def _get_parent(self): @@ -504,12 +505,13 @@ class Column(SchemaItem, expression._ColumnClause): raise exceptions.ArgumentError("this Column already has a table!") if not self._is_oid: self._pre_existing_column = table._columns.get(self.key) - table._columns.add(self) + + table._columns.replace(self) else: self._pre_existing_column = None if self.primary_key: - table.primary_key.add(self) + table.primary_key.replace(self) elif self.key in table.primary_key: raise exceptions.ArgumentError("Trying to redefine primary-key column '%s' as a non-primary-key column on table '%s'" % (self.key, table.fullname)) # if we think this should not raise an error, we'd instead do this: @@ -899,19 +901,20 @@ class PrimaryKeyConstraint(Constraint): self.table = table table.primary_key = self for c in self.__colnames: - self.append_column(table.c[c]) - + self.add(table.c[c]) + def add(self, col): - self.append_column(col) + self.columns.add(col) + col.primary_key=True + append_column = add + + def replace(self, col): + self.columns.replace(col) def remove(self, col): col.primary_key=False del self.columns[col.key] - def append_column(self, col): - self.columns.add(col) - col.primary_key=True - def copy(self): return PrimaryKeyConstraint(name=self.name, *[c.key for c in self]) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 7caee33144..dabc10decb 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -27,6 +27,7 @@ to stay the same in future releases. import re import datetime +import warnings from sqlalchemy import util, exceptions from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes @@ -1464,6 +1465,27 @@ class ColumnCollection(util.OrderedProperties): def __str__(self): return repr([str(c) for c in self]) + def replace(self, column): + """add the given column to this collection, removing unaliased versions of this column + as well as existing columns with the same key. + + e.g.:: + + t = Table('sometable', Column('col1', Integer)) + t.replace_unalised(Column('col1', Integer, key='columnone')) + + will remove the original 'col1' from the collection, and add + the new column under the name 'columnname'. + + Used by schema.Column to override columns during table reflection. + """ + + if column.name in self and column.key != column.name: + other = self[column.name] + if other.name == other.key: + del self[other.name] + util.OrderedProperties.__setitem__(self, column.key, column) + def add(self, column): """Add a column to this collection. @@ -1471,14 +1493,18 @@ class ColumnCollection(util.OrderedProperties): for this dictionary. """ - # Allow an aliased column to replace an unaliased column of the - # same name. - if column.name in self: - other = self[column.name] - if other.name == other.key: - del self[other.name] self[column.key] = column - + + def __setitem__(self, key, value): + if key in self: + # this warning is primarily to catch select() statements which have conflicting + # column names in their exported columns collection + existing = self[key] + if not existing.shares_lineage(value): + table = getattr(existing, 'table', None) and existing.table.description + warnings.warn(RuntimeWarning("Column %r on table %r being replaced by another column with the same key. Consider use_labels for select() statements." % (key, table))) + util.OrderedProperties.__setitem__(self, key, value) + def remove(self, column): del self[column.key] diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 160a315524..65a6ad8fa0 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1232,8 +1232,7 @@ class RequirementsTest(AssertMixin): t5 = Table('ht5', metadata, Column('ht1_id', Integer, ForeignKey('ht1.id'), primary_key=True), - Column('ht1_id', Integer, ForeignKey('ht1.id'), - primary_key=True)) + ) t6 = Table('ht6', metadata, Column('ht1a_id', Integer, ForeignKey('ht1.id'), primary_key=True), diff --git a/test/orm/query.py b/test/orm/query.py index 09c6c2144c..5f85151f0a 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -486,7 +486,29 @@ class ParentTest(QueryTest): class JoinTest(QueryTest): - + + def test_getjoinable_tables(self): + sess = create_session() + + sel1 = select([users]).alias() + sel2 = select([users], from_obj=users.join(addresses)).alias() + + j1 = sel1.join(users, sel1.c.id==users.c.id) + j2 = j1.join(addresses) + + for from_obj, assert_cond in ( + (users, [users]), + (users.join(addresses), [users, addresses]), + (sel1, [sel1]), + (sel2, [sel2]), + (sel1.join(users, sel1.c.id==users.c.id), [sel1, users]), + (sel2.join(users, sel2.c.id==users.c.id), [sel2, users]), + (j2, [j1, j2, sel1, users, addresses]) + + ): + ret = set(sess.query(User).select_from(from_obj)._get_joinable_tables()) + self.assertEquals(ret, set(assert_cond).union([from_obj]), [x.description for x in ret]) + def test_overlapping_paths(self): for aliased in (True,False): # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack) @@ -995,6 +1017,18 @@ class SelectFromTest(QueryTest): Order(description=u'order 5',items=[Item(description=u'item 5',keywords=[])])]) ]) self.assert_sql_count(testbase.db, go, 1) + + sess.clear() + sel2 = orders.select(orders.c.id.in_([1,2,3])) + self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 'red').all(), [ + Order(description=u'order 1',id=1), + Order(description=u'order 2',id=2), + ]) + self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').all(), [ + Order(description=u'order 1',id=1), + Order(description=u'order 2',id=2), + ]) + def test_replace_with_eager(self): mapper(User, users, properties = { @@ -1026,7 +1060,6 @@ class SelectFromTest(QueryTest): self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])) self.assert_sql_count(testbase.db, go, 1) - class CustomJoinTest(QueryTest): keep_mappers = False diff --git a/test/sql/selectable.py b/test/sql/selectable.py index 49a61bf2b9..4796288dfa 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -47,7 +47,10 @@ class SelectableTest(AssertMixin): j2 = jjj.alias('foo') assert j2.corresponding_column(table.c.col1) is j2.c.table1_col1 - + def testselectontable(self): + sel = select([table, table2], use_labels=True) + assert sel.corresponding_column(table.c.col1) is sel.c.table1_col1 + def testjoinagainstjoin(self): j = outerjoin(table, table2, table.c.col1==table2.c.col2) jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo') -- 2.47.3