From: Mike Bayer Date: Wed, 1 Apr 2015 23:18:36 +0000 (-0400) Subject: - :class:`.Query` doesn't support joins, subselects, or special X-Git-Tag: rel_1_0_0b5~4 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=fe1922764151454460dfabfd574d3ead12edf543;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - :class:`.Query` doesn't support joins, subselects, or special FROM clauses when using the :meth:`.Query.update` or :meth:`.Query.delete` methods; instead of silently ignoring these fields if methods like :meth:`.Query.join` or :meth:`.Query.select_from` has been called, an error is raised. In 0.9.10 this only emits a warning. fixes #3349 - don't needlessly call _compile_context() and build up a whole statement that we never need. Construct QueryContext as it's part of the event contract, but don't actually call upon mapper attributes; use more direct systems of determining the update or delete table. - don't realy need _no_select_modifiers anymore --- diff --git a/doc/build/changelog/changelog_10.rst b/doc/build/changelog/changelog_10.rst index cf9ebc1a48..95bb7d0f39 100644 --- a/doc/build/changelog/changelog_10.rst +++ b/doc/build/changelog/changelog_10.rst @@ -18,6 +18,17 @@ .. changelog:: :version: 1.0.0b5 + .. change:: + :tags: bug, orm + :tickets: 3349 + + :class:`.Query` doesn't support joins, subselects, or special + FROM clauses when using the :meth:`.Query.update` or + :meth:`.Query.delete` methods; instead of silently ignoring these + fields if methods like :meth:`.Query.join` or + :meth:`.Query.select_from` has been called, an error is raised. + In 0.9.10 this only emits a warning. + .. change:: :tags: bug, orm diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index c5df63dfda..ff5dda7b3a 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -16,10 +16,11 @@ in unitofwork.py. import operator from itertools import groupby, chain -from .. import sql, util, exc as sa_exc, schema +from .. import sql, util, exc as sa_exc from . import attributes, sync, exc as orm_exc, evaluator from .base import state_str, _attr_as_key, _entity_descriptor from ..sql import expression +from ..sql.base import _from_objects from . import loading @@ -1031,6 +1032,26 @@ class BulkUD(object): def __init__(self, query): self.query = query.enable_eagerloads(False) self.mapper = self.query._bind_mapper() + self._validate_query_state() + + def _validate_query_state(self): + for attr, methname, notset in ( + ('_limit', 'limit()', None), + ('_offset', 'offset()', None), + ('_order_by', 'order_by()', False), + ('_group_by', 'group_by()', False), + ('_distinct', 'distinct()', False), + ( + '_from_obj', + 'join(), outerjoin(), select_from(), or from_self()', + ()) + ): + if getattr(self.query, attr) is not notset: + raise sa_exc.InvalidRequestError( + "Can't call Query.update() or Query.delete() " + "when %s has been called" % + (methname, ) + ) @property def session(self): @@ -1055,18 +1076,34 @@ class BulkUD(object): self._do_post_synchronize() self._do_post() - def _do_pre(self): + @util.dependencies("sqlalchemy.orm.query") + def _do_pre(self, querylib): query = self.query - self.context = context = query._compile_context() - if len(context.statement.froms) != 1 or \ - not isinstance(context.statement.froms[0], schema.Table): + self.context = querylib.QueryContext(query) + + if isinstance(query._entities[0], querylib._ColumnEntity): + # check for special case of query(table) + tables = set() + for ent in query._entities: + if not isinstance(ent, querylib._ColumnEntity): + tables.clear() + break + else: + tables.update(_from_objects(ent.column)) + + if len(tables) != 1: + raise sa_exc.InvalidRequestError( + "This operation requires only one Table or " + "entity be specified as the target." + ) + else: + self.primary_table = tables.pop() + else: self.primary_table = query._only_entity_zero( "This operation requires only one Table or " "entity be specified as the target." ).mapper.local_table - else: - self.primary_table = context.statement.froms[0] session = query.session @@ -1121,7 +1158,8 @@ class BulkFetch(BulkUD): def _do_pre_synchronize(self): query = self.query session = query.session - select_stmt = self.context.statement.with_only_columns( + context = query._compile_context() + select_stmt = context.statement.with_only_columns( self.primary_table.primary_key) self.matched_rows = session.execute( select_stmt, @@ -1134,7 +1172,6 @@ class BulkUpdate(BulkUD): def __init__(self, query, values): super(BulkUpdate, self).__init__(query) - self.query._no_select_modifiers("update") self.values = values @classmethod @@ -1195,7 +1232,6 @@ class BulkDelete(BulkUD): def __init__(self, query): super(BulkDelete, self).__init__(query) - self.query._no_select_modifiers("delete") @classmethod def factory(cls, query, synchronize_session): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 36180e8d5b..9aa2e3d991 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -399,22 +399,6 @@ class Query(object): % (meth, meth) ) - def _no_select_modifiers(self, meth): - if not self._enable_assertions: - return - for attr, methname, notset in ( - ('_limit', 'limit()', None), - ('_offset', 'offset()', None), - ('_order_by', 'order_by()', False), - ('_group_by', 'group_by()', False), - ('_distinct', 'distinct()', False), - ): - if getattr(self, attr) is not notset: - raise sa_exc.InvalidRequestError( - "Can't call Query.%s() when %s has been called" % - (meth, methname) - ) - def _get_options(self, populate_existing=None, version_check=None, only_load_props=None, diff --git a/test/orm/test_query.py b/test/orm/test_query.py index fc9e6e3afe..4c909d6aaf 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -702,39 +702,6 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): text("select * from table")) assert_raises(sa_exc.InvalidRequestError, q.with_polymorphic, User) - def test_cancel_order_by(self): - User = self.classes.User - - s = create_session() - - q = s.query(User).order_by(User.id) - self.assert_compile( - q, - "SELECT users.id AS users_id, users.name AS users_name " - "FROM users ORDER BY users.id", - use_default_dialect=True) - - assert_raises( - sa_exc.InvalidRequestError, q._no_select_modifiers, "foo") - - q = q.order_by(None) - self.assert_compile( - q, - "SELECT users.id AS users_id, users.name AS users_name FROM users", - use_default_dialect=True) - - assert_raises( - sa_exc.InvalidRequestError, q._no_select_modifiers, "foo") - - q = q.order_by(False) - self.assert_compile( - q, - "SELECT users.id AS users_id, users.name AS users_name FROM users", - use_default_dialect=True) - - # after False was set, this should pass - q._no_select_modifiers("foo") - def test_mapper_zero(self): User, Address = self.classes.User, self.classes.Address diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index a3ad37e605..dedc2133bd 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -19,12 +19,20 @@ class UpdateDeleteTest(fixtures.MappedTest): test_needs_autoincrement=True), Column('name', String(32)), Column('age_int', Integer)) + Table( + "addresses", metadata, + Column('id', Integer, primary_key=True), + Column('user_id', ForeignKey('users.id')) + ) @classmethod def setup_classes(cls): class User(cls.Comparable): pass + class Address(cls.Comparable): + pass + @classmethod def insert_data(cls): users = cls.tables.users @@ -41,9 +49,14 @@ class UpdateDeleteTest(fixtures.MappedTest): User = cls.classes.User users = cls.tables.users + Address = cls.classes.Address + addresses = cls.tables.addresses + mapper(User, users, properties={ - 'age': users.c.age_int + 'age': users.c.age_int, + 'addresses': relationship(Address) }) + mapper(Address, addresses) def test_illegal_eval(self): User = self.classes.User @@ -59,27 +72,36 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_illegal_operations(self): User = self.classes.User + Address = self.classes.Address s = Session() for q, mname in ( - (s.query(User).limit(2), "limit"), - (s.query(User).offset(2), "offset"), - (s.query(User).limit(2).offset(2), "limit"), - (s.query(User).order_by(User.id), "order_by"), - (s.query(User).group_by(User.id), "group_by"), - (s.query(User).distinct(), "distinct") + (s.query(User).limit(2), r"limit\(\)"), + (s.query(User).offset(2), r"offset\(\)"), + (s.query(User).limit(2).offset(2), r"limit\(\)"), + (s.query(User).order_by(User.id), r"order_by\(\)"), + (s.query(User).group_by(User.id), r"group_by\(\)"), + (s.query(User).distinct(), r"distinct\(\)"), + (s.query(User).join(User.addresses), + r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)"), + (s.query(User).outerjoin(User.addresses), + r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)"), + (s.query(User).select_from(Address), + r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)"), + (s.query(User).from_self(), + r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)"), ): assert_raises_message( exc.InvalidRequestError, - r"Can't call Query.update\(\) when " - "%s\(\) has been called" % mname, + r"Can't call Query.update\(\) or Query.delete\(\) when " + "%s has been called" % mname, q.update, {'name': 'ed'}) assert_raises_message( exc.InvalidRequestError, - r"Can't call Query.delete\(\) when " - "%s\(\) has been called" % mname, + r"Can't call Query.update\(\) or Query.delete\(\) when " + "%s has been called" % mname, q.delete) def test_evaluate_clauseelement(self):