]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The :meth:`.Query.update` method will now convert string key
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Oct 2014 18:36:56 +0000 (14:36 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Oct 2014 18:36:56 +0000 (14:36 -0400)
names in the given dictionary of values into mapped attribute names
against the mapped class being updated.  Previously, string names
were taken in directly and passed to the core update statement without
any means to resolve against the mapped entity.  Support for synonyms
and hybrid attributes as the subject attributes of
:meth:`.Query.update` are also supported.
fixes #3228

doc/build/changelog/changelog_10.rst
doc/build/changelog/migration_10.rst
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
test/orm/test_update_delete.py

index 66fa2ad267eee3ec6dbe21f011b2a7b653e3c67b..ec812a091691127f4d94fe7eae5bd6098a2db182 100644 (file)
     series as well.  For changes that are specific to 1.0 with an emphasis
     on compatibility concerns, see :doc:`/changelog/migration_10`.
 
+    .. change::
+        :tags: bug, orm
+        :tickets: 3228
+
+        The :meth:`.Query.update` method will now convert string key
+        names in the given dictionary of values into mapped attribute names
+        against the mapped class being updated.  Previously, string names
+        were taken in directly and passed to the core update statement without
+        any means to resolve against the mapped entity.  Support for synonyms
+        and hybrid attributes as the subject attributes of
+        :meth:`.Query.update` are also supported.
+
+        .. seealso::
+
+            :ref:`bug_3228`
+
     .. change::
         :tags: bug, orm
         :tickets: 3035
index dd8964f8bbb88cfd61752030afa9fedfc5d57cad..3591ee0e2444cea9d61dd0753f6bb0e50100bfc3 100644 (file)
@@ -510,6 +510,7 @@ of inheritance-oriented scenarios, including:
 
 :ticket:`3035`
 
+
 .. _feature_3178:
 
 New systems to safely emit parameterized warnings
@@ -793,6 +794,62 @@ would again fail; these have also been fixed.
 Behavioral Changes - ORM
 ========================
 
+.. _bug_3228:
+
+query.update() now resolves string names into mapped attribute names
+--------------------------------------------------------------------
+
+The documentation for :meth:`.Query.update` states that the given
+``values`` dictionary is "a dictionary with attributes names as keys",
+implying that these are mapped attribute names.  Unfortunately, the function
+was designed more in mind to receive attributes and SQL expressions and
+not as much strings; when strings
+were passed, these strings would be passed through straight to the core
+update statement without any resolution as far as how these names are
+represented on the mapped class, meaning the name would have to match that
+of a table column exactly, not how an attribute of that name was mapped
+onto the class.
+
+The string names are now resolved as attribute names in earnest::
+
+    class User(Base):
+        __tablename__ = 'user'
+
+        id = Column(Integer, primary_key=True)
+        name = Column('user_name', String(50))
+
+Above, the column ``user_name`` is mapped as ``name``.  Previously,
+a call to :meth:`.Query.update` that was passed strings would have to
+have been called as follows::
+
+    session.query(User).update({'user_name': 'moonbeam'})
+
+The given string is now resolved against the entity::
+
+    session.query(User).update({'name': 'moonbeam'})
+
+It is typically preferable to use the attribute directly, to avoid any
+ambiguity::
+
+    session.query(User).update({User.name: 'moonbeam'})
+
+The change also indicates that synonyms and hybrid attributes can be referred
+to by string name as well::
+
+    class User(Base):
+        __tablename__ = 'user'
+
+        id = Column(Integer, primary_key=True)
+        name = Column('user_name', String(50))
+
+        @hybrid_property
+        def fullname(self):
+            return self.name
+
+    session.query(User).update({'fullname': 'moonbeam'})
+
+:ticket:`3228`
+
 .. _migration_3061:
 
 Changes to attribute events and other operations regarding attributes that have no pre-existing value
index 74e69e44c51e19a7e5186a01af36bee9d1204651..114b79ea543659b77ee8bf34a0e21d5fd8aeb3fc 100644 (file)
@@ -18,7 +18,7 @@ import operator
 from itertools import groupby
 from .. import sql, util, exc as sa_exc, schema
 from . import attributes, sync, exc as orm_exc, evaluator
-from .base import state_str, _attr_as_key
+from .base import state_str, _attr_as_key, _entity_descriptor
 from ..sql import expression
 from . import loading
 
@@ -987,6 +987,7 @@ class BulkUpdate(BulkUD):
         super(BulkUpdate, self).__init__(query)
         self.query._no_select_modifiers("update")
         self.values = values
+        self.mapper = self.query._mapper_zero_or_none()
 
     @classmethod
     def factory(cls, query, synchronize_session, values):
@@ -996,9 +997,40 @@ class BulkUpdate(BulkUD):
             False: BulkUpdate
         }, synchronize_session, query, values)
 
+    def _resolve_string_to_expr(self, key):
+        if self.mapper and isinstance(key, util.string_types):
+            attr = _entity_descriptor(self.mapper, key)
+            return attr.__clause_element__()
+        else:
+            return key
+
+    def _resolve_key_to_attrname(self, key):
+        if self.mapper and isinstance(key, util.string_types):
+            attr = _entity_descriptor(self.mapper, key)
+            return attr.property.key
+        elif isinstance(key, attributes.InstrumentedAttribute):
+            return key.key
+        elif hasattr(key, '__clause_element__'):
+            key = key.__clause_element__()
+
+        if self.mapper and isinstance(key, expression.ColumnElement):
+            try:
+                attr = self.mapper._columntoproperty[key]
+            except orm_exc.UnmappedColumnError:
+                return None
+            else:
+                return attr.key
+        else:
+            raise sa_exc.InvalidRequestError(
+                "Invalid expression type: %r" % key)
+
     def _do_exec(self):
+        values = dict(
+            (self._resolve_string_to_expr(k), v)
+            for k, v in self.values.items()
+        )
         update_stmt = sql.update(self.primary_table,
-                                 self.context.whereclause, self.values)
+                                 self.context.whereclause, values)
 
         self.result = self.query.session.execute(
             update_stmt, params=self.query._params)
@@ -1044,9 +1076,10 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
     def _additional_evaluators(self, evaluator_compiler):
         self.value_evaluators = {}
         for key, value in self.values.items():
-            key = _attr_as_key(key)
-            self.value_evaluators[key] = evaluator_compiler.process(
-                expression._literal_as_binds(value))
+            key = self._resolve_key_to_attrname(key)
+            if key is not None:
+                self.value_evaluators[key] = evaluator_compiler.process(
+                    expression._literal_as_binds(value))
 
     def _do_post_synchronize(self):
         session = self.query.session
index 7b2ea7977e4a19a077786c59ddc7a2ecdcc64b43..fce7a3665c316d931d78db2d07a8b049fde6e6b8 100644 (file)
@@ -2756,9 +2756,25 @@ class Query(object):
 
         Updates rows matched by this query in the database.
 
-        :param values: a dictionary with attributes names as keys and literal
+        E.g.::
+
+            sess.query(User).filter(User.age == 25).\
+                update({User.age: User.age - 10}, synchronize_session='fetch')
+
+
+            sess.query(User).filter(User.age == 25).\
+                update({"age": User.age - 10}, synchronize_session='evaluate')
+
+
+        :param values: a dictionary with attributes names, or alternatively
+          mapped attributes or SQL expressions, as keys, and literal
           values or sql expressions as values.
 
+          .. versionchanged:: 1.0.0 - string names in the values dictionary
+             are now resolved against the mapped entity; previously, these
+             strings were passed as literal column names with no mapper-level
+             translation.
+
         :param synchronize_session: chooses the strategy to update the
             attributes on objects in the session. Valid values are:
 
@@ -2796,7 +2812,7 @@ class Query(object):
           which normally occurs upon :meth:`.Session.commit` or can be forced
           by using :meth:`.Session.expire_all`.
 
-        * As of 0.8, this method will support multiple table updates, as
+        * The method supports multiple table updates, as
           detailed in :ref:`multi_table_updates`, and this behavior does
           extend to support updates of joined-inheritance and other multiple
           table mappings.  However, the **join condition of an inheritance
@@ -2827,12 +2843,6 @@ class Query(object):
 
         """
 
-        # TODO: value keys need to be mapped to corresponding sql cols and
-        # instr.attr.s to string keys
-        # TODO: updates of manytoone relationships need to be converted to
-        # fk assignments
-        # TODO: cascades need handling.
-
         update_op = persistence.BulkUpdate.factory(
             self, synchronize_session, values)
         update_op.exec_()
index a737a2e1d7e24e7ca225635d7952f3da3166535b..a3ad37e60573d01807da161d5bada1b2a1ce04c3 100644 (file)
@@ -1,9 +1,9 @@
 from sqlalchemy.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy import Integer, String, ForeignKey, or_, exc, \
-    select, func, Boolean, case, text
+    select, func, Boolean, case, text, column
 from sqlalchemy.orm import mapper, relationship, backref, Session, \
-    joinedload
+    joinedload, synonym
 from sqlalchemy import testing
 
 from sqlalchemy.testing.schema import Table, Column
@@ -18,7 +18,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
               Column('id', Integer, primary_key=True,
                      test_needs_autoincrement=True),
               Column('name', String(32)),
-              Column('age', Integer))
+              Column('age_int', Integer))
 
     @classmethod
     def setup_classes(cls):
@@ -30,10 +30,10 @@ class UpdateDeleteTest(fixtures.MappedTest):
         users = cls.tables.users
 
         users.insert().execute([
-            dict(id=1, name='john', age=25),
-            dict(id=2, name='jack', age=47),
-            dict(id=3, name='jill', age=29),
-            dict(id=4, name='jane', age=37),
+            dict(id=1, name='john', age_int=25),
+            dict(id=2, name='jack', age_int=47),
+            dict(id=3, name='jill', age_int=29),
+            dict(id=4, name='jane', age_int=37),
         ])
 
     @classmethod
@@ -41,7 +41,9 @@ class UpdateDeleteTest(fixtures.MappedTest):
         User = cls.classes.User
         users = cls.tables.users
 
-        mapper(User, users)
+        mapper(User, users, properties={
+            'age': users.c.age_int
+        })
 
     def test_illegal_eval(self):
         User = self.classes.User
@@ -80,6 +82,108 @@ class UpdateDeleteTest(fixtures.MappedTest):
                 "%s\(\) has been called" % mname,
                 q.delete)
 
+    def test_evaluate_clauseelement(self):
+        User = self.classes.User
+
+        class Thing(object):
+            def __clause_element__(self):
+                return User.name.__clause_element__()
+
+        s = Session()
+        jill = s.query(User).get(3)
+        s.query(User).update(
+            {Thing(): 'moonbeam'},
+            synchronize_session='evaluate')
+        eq_(jill.name, 'moonbeam')
+
+    def test_evaluate_invalid(self):
+        User = self.classes.User
+
+        class Thing(object):
+            def __clause_element__(self):
+                return 5
+
+        s = Session()
+
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Invalid expression type: 5",
+            s.query(User).update, {Thing(): 'moonbeam'},
+            synchronize_session='evaluate'
+        )
+
+    def test_evaluate_unmapped_col(self):
+        User = self.classes.User
+
+        s = Session()
+        jill = s.query(User).get(3)
+        s.query(User).update(
+            {column('name'): 'moonbeam'},
+            synchronize_session='evaluate')
+        eq_(jill.name, 'jill')
+        s.expire(jill)
+        eq_(jill.name, 'moonbeam')
+
+    def test_evaluate_synonym_string(self):
+        class Foo(object):
+            pass
+        mapper(Foo, self.tables.users, properties={
+            'uname': synonym("name", )
+        })
+
+        s = Session()
+        jill = s.query(Foo).get(3)
+        s.query(Foo).update(
+            {'uname': 'moonbeam'},
+            synchronize_session='evaluate')
+        eq_(jill.uname, 'moonbeam')
+
+    def test_evaluate_synonym_attr(self):
+        class Foo(object):
+            pass
+        mapper(Foo, self.tables.users, properties={
+            'uname': synonym("name", )
+        })
+
+        s = Session()
+        jill = s.query(Foo).get(3)
+        s.query(Foo).update(
+            {Foo.uname: 'moonbeam'},
+            synchronize_session='evaluate')
+        eq_(jill.uname, 'moonbeam')
+
+    def test_evaluate_double_synonym_attr(self):
+        class Foo(object):
+            pass
+        mapper(Foo, self.tables.users, properties={
+            'uname': synonym("name"),
+            'ufoo': synonym('uname')
+        })
+
+        s = Session()
+        jill = s.query(Foo).get(3)
+        s.query(Foo).update(
+            {Foo.ufoo: 'moonbeam'},
+            synchronize_session='evaluate')
+        eq_(jill.ufoo, 'moonbeam')
+
+    def test_evaluate_hybrid_attr(self):
+        from sqlalchemy.ext.hybrid import hybrid_property
+
+        class Foo(object):
+            @hybrid_property
+            def uname(self):
+                return self.name
+
+        mapper(Foo, self.tables.users)
+
+        s = Session()
+        jill = s.query(Foo).get(3)
+        s.query(Foo).update(
+            {Foo.uname: 'moonbeam'},
+            synchronize_session='evaluate')
+        eq_(jill.uname, 'moonbeam')
+
     def test_delete(self):
         User = self.classes.User
 
@@ -208,7 +312,8 @@ class UpdateDeleteTest(fixtures.MappedTest):
 
         sess.query(User).filter(User.age > 27).\
             update(
-                {users.c.age: User.age - 10}, synchronize_session='evaluate')
+                {users.c.age_int: User.age - 10},
+                synchronize_session='evaluate')
         eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 19, 27])
         eq_(sess.query(User.age).order_by(
             User.id).all(), list(zip([25, 27, 19, 27])))
@@ -219,12 +324,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
         eq_(sess.query(User.age).order_by(
             User.id).all(), list(zip([15, 27, 19, 27])))
 
+    def test_update_against_table_col(self):
+        User, users = self.classes.User, self.tables.users
+
+        sess = Session()
+        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
+        eq_([john.age, jack.age, jill.age, jane.age], [25, 47, 29, 37])
+        sess.query(User).filter(User.age > 27).\
+            update(
+                {users.c.age_int: User.age - 10},
+                synchronize_session='evaluate')
+        eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 19, 27])
+
     def test_update_against_metadata(self):
         User, users = self.classes.User, self.tables.users
 
         sess = Session()
 
-        sess.query(users).update({users.c.age: 29}, synchronize_session=False)
+        sess.query(users).update(
+            {users.c.age_int: 29}, synchronize_session=False)
         eq_(sess.query(User.age).order_by(
             User.id).all(), list(zip([29, 29, 29, 29])))
 
@@ -235,7 +353,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
 
         john, jack, jill, jane = sess.query(User).order_by(User.id).all()
 
-        sess.query(User).filter(text('age > :x')).params(x=29).\
+        sess.query(User).filter(text('age_int > :x')).params(x=29).\
             update({'age': User.age - 10}, synchronize_session='fetch')
 
         eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])