]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- :class:`.Query` doesn't support joins, subselects, or special
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Apr 2015 23:18:36 +0000 (19:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Apr 2015 23:37:43 +0000 (19:37 -0400)
FROM clauses when using the :meth:`.Query.update` or
:meth:`.Query.delete` methods; instead of silently ignoring these
fields if methods like :meth:`.Query.join` or
:meth:`.Query.select_from` has been called, an error is raised.
In 0.9.10 this only emits a warning.
fixes #3349
- don't needlessly call _compile_context() and build up a
whole statement that we never need.  Construct QueryContext
as it's part of the event contract, but don't actually call upon
mapper attributes; use more direct systems of determining the
update or delete table.
- don't realy need _no_select_modifiers anymore

doc/build/changelog/changelog_10.rst
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
test/orm/test_query.py
test/orm/test_update_delete.py

index cf9ebc1a488d800bd4e260b8ab290343122752fe..95bb7d0f394cf5a2debde2483651ff639697c190 100644 (file)
 .. changelog::
     :version: 1.0.0b5
 
+    .. change::
+        :tags: bug, orm
+        :tickets: 3349
+
+        :class:`.Query` doesn't support joins, subselects, or special
+        FROM clauses when using the :meth:`.Query.update` or
+        :meth:`.Query.delete` methods; instead of silently ignoring these
+        fields if methods like :meth:`.Query.join` or
+        :meth:`.Query.select_from` has been called, an error is raised.
+        In 0.9.10 this only emits a warning.
+
     .. change::
         :tags: bug, orm
 
index c5df63dfda077c6ac66f8251a7b3c8ab59724fce..ff5dda7b3a74f619e924260ebdb0b9fe3fd66f68 100644 (file)
@@ -16,10 +16,11 @@ in unitofwork.py.
 
 import operator
 from itertools import groupby, chain
-from .. import sql, util, exc as sa_exc, schema
+from .. import sql, util, exc as sa_exc
 from . import attributes, sync, exc as orm_exc, evaluator
 from .base import state_str, _attr_as_key, _entity_descriptor
 from ..sql import expression
+from ..sql.base import _from_objects
 from . import loading
 
 
@@ -1031,6 +1032,26 @@ class BulkUD(object):
     def __init__(self, query):
         self.query = query.enable_eagerloads(False)
         self.mapper = self.query._bind_mapper()
+        self._validate_query_state()
+
+    def _validate_query_state(self):
+        for attr, methname, notset in (
+            ('_limit', 'limit()', None),
+            ('_offset', 'offset()', None),
+            ('_order_by', 'order_by()', False),
+            ('_group_by', 'group_by()', False),
+            ('_distinct', 'distinct()', False),
+            (
+                '_from_obj',
+                'join(), outerjoin(), select_from(), or from_self()',
+                ())
+        ):
+            if getattr(self.query, attr) is not notset:
+                raise sa_exc.InvalidRequestError(
+                    "Can't call Query.update() or Query.delete() "
+                    "when %s has been called" %
+                    (methname, )
+                )
 
     @property
     def session(self):
@@ -1055,18 +1076,34 @@ class BulkUD(object):
         self._do_post_synchronize()
         self._do_post()
 
-    def _do_pre(self):
+    @util.dependencies("sqlalchemy.orm.query")
+    def _do_pre(self, querylib):
         query = self.query
-        self.context = context = query._compile_context()
-        if len(context.statement.froms) != 1 or \
-                not isinstance(context.statement.froms[0], schema.Table):
+        self.context = querylib.QueryContext(query)
+
+        if isinstance(query._entities[0], querylib._ColumnEntity):
+            # check for special case of query(table)
+            tables = set()
+            for ent in query._entities:
+                if not isinstance(ent, querylib._ColumnEntity):
+                    tables.clear()
+                    break
+                else:
+                    tables.update(_from_objects(ent.column))
+
+            if len(tables) != 1:
+                raise sa_exc.InvalidRequestError(
+                    "This operation requires only one Table or "
+                    "entity be specified as the target."
+                )
+            else:
+                self.primary_table = tables.pop()
 
+        else:
             self.primary_table = query._only_entity_zero(
                 "This operation requires only one Table or "
                 "entity be specified as the target."
             ).mapper.local_table
-        else:
-            self.primary_table = context.statement.froms[0]
 
         session = query.session
 
@@ -1121,7 +1158,8 @@ class BulkFetch(BulkUD):
     def _do_pre_synchronize(self):
         query = self.query
         session = query.session
-        select_stmt = self.context.statement.with_only_columns(
+        context = query._compile_context()
+        select_stmt = context.statement.with_only_columns(
             self.primary_table.primary_key)
         self.matched_rows = session.execute(
             select_stmt,
@@ -1134,7 +1172,6 @@ class BulkUpdate(BulkUD):
 
     def __init__(self, query, values):
         super(BulkUpdate, self).__init__(query)
-        self.query._no_select_modifiers("update")
         self.values = values
 
     @classmethod
@@ -1195,7 +1232,6 @@ class BulkDelete(BulkUD):
 
     def __init__(self, query):
         super(BulkDelete, self).__init__(query)
-        self.query._no_select_modifiers("delete")
 
     @classmethod
     def factory(cls, query, synchronize_session):
index 36180e8d5b5a48ca496068202b37f8f02892f1b6..9aa2e3d991b1e6d6f4adf36e6bc1a4780ee3a82b 100644 (file)
@@ -399,22 +399,6 @@ class Query(object):
                 % (meth, meth)
             )
 
-    def _no_select_modifiers(self, meth):
-        if not self._enable_assertions:
-            return
-        for attr, methname, notset in (
-            ('_limit', 'limit()', None),
-            ('_offset', 'offset()', None),
-            ('_order_by', 'order_by()', False),
-            ('_group_by', 'group_by()', False),
-            ('_distinct', 'distinct()', False),
-        ):
-            if getattr(self, attr) is not notset:
-                raise sa_exc.InvalidRequestError(
-                    "Can't call Query.%s() when %s has been called" %
-                    (meth, methname)
-                )
-
     def _get_options(self, populate_existing=None,
                      version_check=None,
                      only_load_props=None,
index fc9e6e3afe335c5e47f96fbab80d476e14c6fda8..4c909d6aaf5b05d24bc880a5731f24e14951844d 100644 (file)
@@ -702,39 +702,6 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL):
             text("select * from table"))
         assert_raises(sa_exc.InvalidRequestError, q.with_polymorphic, User)
 
-    def test_cancel_order_by(self):
-        User = self.classes.User
-
-        s = create_session()
-
-        q = s.query(User).order_by(User.id)
-        self.assert_compile(
-            q,
-            "SELECT users.id AS users_id, users.name AS users_name "
-            "FROM users ORDER BY users.id",
-            use_default_dialect=True)
-
-        assert_raises(
-            sa_exc.InvalidRequestError, q._no_select_modifiers, "foo")
-
-        q = q.order_by(None)
-        self.assert_compile(
-            q,
-            "SELECT users.id AS users_id, users.name AS users_name FROM users",
-            use_default_dialect=True)
-
-        assert_raises(
-            sa_exc.InvalidRequestError, q._no_select_modifiers, "foo")
-
-        q = q.order_by(False)
-        self.assert_compile(
-            q,
-            "SELECT users.id AS users_id, users.name AS users_name FROM users",
-            use_default_dialect=True)
-
-        # after False was set, this should pass
-        q._no_select_modifiers("foo")
-
     def test_mapper_zero(self):
         User, Address = self.classes.User, self.classes.Address
 
index a3ad37e60573d01807da161d5bada1b2a1ce04c3..dedc2133bd12247b1c78f9e8bb25342d90842478 100644 (file)
@@ -19,12 +19,20 @@ class UpdateDeleteTest(fixtures.MappedTest):
                      test_needs_autoincrement=True),
               Column('name', String(32)),
               Column('age_int', Integer))
+        Table(
+            "addresses", metadata,
+            Column('id', Integer, primary_key=True),
+            Column('user_id', ForeignKey('users.id'))
+        )
 
     @classmethod
     def setup_classes(cls):
         class User(cls.Comparable):
             pass
 
+        class Address(cls.Comparable):
+            pass
+
     @classmethod
     def insert_data(cls):
         users = cls.tables.users
@@ -41,9 +49,14 @@ class UpdateDeleteTest(fixtures.MappedTest):
         User = cls.classes.User
         users = cls.tables.users
 
+        Address = cls.classes.Address
+        addresses = cls.tables.addresses
+
         mapper(User, users, properties={
-            'age': users.c.age_int
+            'age': users.c.age_int,
+            'addresses': relationship(Address)
         })
+        mapper(Address, addresses)
 
     def test_illegal_eval(self):
         User = self.classes.User
@@ -59,27 +72,36 @@ class UpdateDeleteTest(fixtures.MappedTest):
 
     def test_illegal_operations(self):
         User = self.classes.User
+        Address = self.classes.Address
 
         s = Session()
 
         for q, mname in (
-            (s.query(User).limit(2), "limit"),
-            (s.query(User).offset(2), "offset"),
-            (s.query(User).limit(2).offset(2), "limit"),
-            (s.query(User).order_by(User.id), "order_by"),
-            (s.query(User).group_by(User.id), "group_by"),
-            (s.query(User).distinct(), "distinct")
+            (s.query(User).limit(2), r"limit\(\)"),
+            (s.query(User).offset(2), r"offset\(\)"),
+            (s.query(User).limit(2).offset(2), r"limit\(\)"),
+            (s.query(User).order_by(User.id), r"order_by\(\)"),
+            (s.query(User).group_by(User.id), r"group_by\(\)"),
+            (s.query(User).distinct(), r"distinct\(\)"),
+            (s.query(User).join(User.addresses),
+                r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)"),
+            (s.query(User).outerjoin(User.addresses),
+                r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)"),
+            (s.query(User).select_from(Address),
+                r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)"),
+            (s.query(User).from_self(),
+                r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)"),
         ):
             assert_raises_message(
                 exc.InvalidRequestError,
-                r"Can't call Query.update\(\) when "
-                "%s\(\) has been called" % mname,
+                r"Can't call Query.update\(\) or Query.delete\(\) when "
+                "%s has been called" % mname,
                 q.update,
                 {'name': 'ed'})
             assert_raises_message(
                 exc.InvalidRequestError,
-                r"Can't call Query.delete\(\) when "
-                "%s\(\) has been called" % mname,
+                r"Can't call Query.update\(\) or Query.delete\(\) when "
+                "%s has been called" % mname,
                 q.delete)
 
     def test_evaluate_clauseelement(self):