]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow SQL expressions to be set on PK columns
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Sep 2016 15:33:16 +0000 (11:33 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Sep 2016 15:33:16 +0000 (11:33 -0400)
Removes an unnecessary transfer of modified PK column
value to the params dictionary, so that if the modified PK column
is already present in value_params, this remains in effect.  Also
propagate a new flag through to _emit_update_statements() that will
trip "return_defaults()" across the board if a PK col w/ SQL expression
change is present, and pull this PK value in _postfetch as well assuming
we're an UPDATE.

Change-Id: I9ae87f964df9ba8faea8e25e96b8327f968e5d1b
Fixes: #3801
doc/build/changelog/changelog_11.rst
lib/sqlalchemy/orm/persistence.py
test/orm/test_naturalpks.py
test/orm/test_versioning.py

index a0970348911e60d784006429616a64802bfcb2f5..6aa5624dd5f8a8cf26a878974b2722f417fb1f78 100644 (file)
 .. changelog::
     :version: 1.1.0
 
+    .. change::
+        :tags: bug, orm
+        :tickets: 3801
+
+        An UPDATE emitted from the ORM flush process can now accommodate a
+        SQL expression element for a column within the primary key of an
+        object, if the target database supports RETURNING in order to provide
+        the new value, or if the PK value is set "to itself" for the purposes
+        of bumping some other trigger / onupdate on the column.
+
     .. change::
         :tags: bug, orm
         :tickets: 3788
index 0b029f46686cbd62ce902de6469b28ced4dbe9c3..56b028375694f865631bb2719781fcf463159bd4 100644 (file)
@@ -506,6 +506,7 @@ def _collect_update_commands(
         elif not (params or value_params):
             continue
 
+        has_all_pks = True
         if bulk:
             pk_params = dict(
                 (propkey_to_col[propkey]._label, state_dict.get(propkey))
@@ -530,7 +531,8 @@ def _collect_update_commands(
                     else:
                         # else, use the old value to locate the row
                         pk_params[col._label] = history.deleted[0]
-                        params[col.key] = history.added[0]
+                        if col in value_params:
+                            has_all_pks = False
                 else:
                     pk_params[col._label] = history.unchanged[0]
                 if pk_params[col._label] is None:
@@ -542,7 +544,7 @@ def _collect_update_commands(
             params.update(pk_params)
             yield (
                 state, state_dict, params, mapper,
-                connection, value_params, has_all_defaults)
+                connection, value_params, has_all_defaults, has_all_pks)
 
 
 def _collect_post_update_commands(base_mapper, uowtransaction, table,
@@ -636,14 +638,15 @@ def _emit_update_statements(base_mapper, uowtransaction,
 
     cached_stmt = base_mapper._memo(('update', table), update_stmt)
 
-    for (connection, paramkeys, hasvalue, has_all_defaults), \
+    for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \
         records in groupby(
             update,
             lambda rec: (
                 rec[4],  # connection
                 set(rec[2]),  # set of parameter keys
                 bool(rec[5]),  # whether or not we have "value" parameters
-                rec[6]  # has_all_defaults
+                rec[6],  # has_all_defaults
+                rec[7]  # has all pks
             )
     ):
         rows = 0
@@ -659,7 +662,9 @@ def _emit_update_statements(base_mapper, uowtransaction,
             connection.dialect.supports_sane_multi_rowcount
         allow_multirow = has_all_defaults and not needs_version_id
 
-        if bookkeeping and not has_all_defaults and \
+        if not has_all_pks:
+            statement = statement.return_defaults()
+        elif bookkeeping and not has_all_defaults and \
                 mapper.base_mapper.eager_defaults:
             statement = statement.return_defaults()
         elif mapper.version_id_col is not None:
@@ -667,7 +672,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
 
         if hasvalue:
             for state, state_dict, params, mapper, \
-                    connection, value_params, has_all_defaults in records:
+                    connection, value_params, \
+                    has_all_defaults, has_all_pks in records:
                 c = connection.execute(
                     statement.values(value_params),
                     params)
@@ -687,7 +693,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
             if not allow_multirow:
                 check_rowcount = assert_singlerow
                 for state, state_dict, params, mapper, \
-                        connection, value_params, has_all_defaults in records:
+                        connection, value_params, has_all_defaults, \
+                        has_all_pks in records:
                     c = cached_connections[connection].\
                         execute(statement, params)
 
@@ -717,7 +724,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
                 rows += c.rowcount
 
                 for state, state_dict, params, mapper, \
-                        connection, value_params, has_all_defaults in records:
+                        connection, value_params, \
+                        has_all_defaults, has_all_pks in records:
                     if bookkeeping:
                         _postfetch(
                             mapper,
@@ -1013,7 +1021,9 @@ def _postfetch(mapper, uowtransaction, table,
         row = result.context.returned_defaults
         if row is not None:
             for col in returning_cols:
-                if col.primary_key:
+                # pk cols returned from insert are handled
+                # distinctly, don't step on the values here
+                if col.primary_key and result.context.isinsert:
                     continue
                 dict_[mapper._columntoproperty[col].key] = row[col]
                 if refresh_flush:
index 60387ddce2b0ac0adcb6b8114e68dc157ff4295a..6780967c9b0ed918b0ebde85eb410cea619536d5 100644 (file)
@@ -122,6 +122,43 @@ class NaturalPKTest(fixtures.MappedTest):
         assert sess.query(User).get('jack') is None
         assert sess.query(User).get('ed').fullname == 'jack'
 
+    @testing.requires.returning
+    def test_update_to_sql_expr(self):
+        users, User = self.tables.users, self.classes.User
+
+        mapper(User, users)
+
+        sess = create_session()
+        u1 = User(username='jack', fullname='jack')
+
+        sess.add(u1)
+        sess.flush()
+
+        u1.username = User.username + ' jones'
+
+        sess.flush()
+
+        eq_(u1.username, 'jack jones')
+
+    def test_update_to_self_sql_expr(self):
+        # SQL expression where the PK won't actually change,
+        # such as to bump a server side trigger
+        users, User = self.tables.users, self.classes.User
+
+        mapper(User, users)
+
+        sess = create_session()
+        u1 = User(username='jack', fullname='jack')
+
+        sess.add(u1)
+        sess.flush()
+
+        u1.username = User.username + ''
+
+        sess.flush()
+
+        eq_(u1.username, 'jack')
+
     def test_flush_new_pk_after_expire(self):
         User, users = self.classes.User, self.tables.users
 
index 07b090c60de91f6622dcc24f8788b82f844e712d..40b37309701f27f302e434c0de8440a7f4266af2 100644 (file)
@@ -1020,6 +1020,38 @@ class ServerVersioningTest(fixtures.MappedTest):
             )
         self.assert_sql_execution(testing.db, sess.flush, *statements)
 
+    def test_sql_expr_bump(self):
+        sess = self._fixture()
+
+        f1 = self.classes.Foo(value='f1')
+        sess.add(f1)
+        sess.flush()
+
+        eq_(f1.version_id, 1)
+
+        f1.id = self.classes.Foo.id + 0
+
+        sess.flush()
+
+        eq_(f1.version_id, 2)
+
+    @testing.requires.returning
+    def test_sql_expr_w_mods_bump(self):
+        sess = self._fixture()
+
+        f1 = self.classes.Foo(id=2, value='f1')
+        sess.add(f1)
+        sess.flush()
+
+        eq_(f1.version_id, 1)
+
+        f1.id = self.classes.Foo.id + 3
+
+        sess.flush()
+
+        eq_(f1.id, 5)
+        eq_(f1.version_id, 2)
+
     def test_multi_update(self):
         sess = self._fixture()