From: Mike Bayer Date: Tue, 20 Sep 2016 15:33:16 +0000 (-0400) Subject: Allow SQL expressions to be set on PK columns X-Git-Tag: rel_1_1_0~24^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f8ecdf47f0975b8b4e357fde2008d9aae8c50239;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Allow SQL expressions to be set on PK columns Removes an unnecessary transfer of modified PK column value to the params dictionary, so that if the modified PK column is already present in value_params, this remains in effect. Also propagate a new flag through to _emit_update_statements() that will trip "return_defaults()" across the board if a PK col w/ SQL expression change is present, and pull this PK value in _postfetch as well assuming we're an UPDATE. Change-Id: I9ae87f964df9ba8faea8e25e96b8327f968e5d1b Fixes: #3801 --- diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst index a097034891..6aa5624dd5 100644 --- a/doc/build/changelog/changelog_11.rst +++ b/doc/build/changelog/changelog_11.rst @@ -21,6 +21,16 @@ .. changelog:: :version: 1.1.0 + .. change:: + :tags: bug, orm + :tickets: 3801 + + An UPDATE emitted from the ORM flush process can now accommodate a + SQL expression element for a column within the primary key of an + object, if the target database supports RETURNING in order to provide + the new value, or if the PK value is set "to itself" for the purposes + of bumping some other trigger / onupdate on the column. + .. change:: :tags: bug, orm :tickets: 3788 diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 0b029f4668..56b0283756 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -506,6 +506,7 @@ def _collect_update_commands( elif not (params or value_params): continue + has_all_pks = True if bulk: pk_params = dict( (propkey_to_col[propkey]._label, state_dict.get(propkey)) @@ -530,7 +531,8 @@ def _collect_update_commands( else: # else, use the old value to locate the row pk_params[col._label] = history.deleted[0] - params[col.key] = history.added[0] + if col in value_params: + has_all_pks = False else: pk_params[col._label] = history.unchanged[0] if pk_params[col._label] is None: @@ -542,7 +544,7 @@ def _collect_update_commands( params.update(pk_params) yield ( state, state_dict, params, mapper, - connection, value_params, has_all_defaults) + connection, value_params, has_all_defaults, has_all_pks) def _collect_post_update_commands(base_mapper, uowtransaction, table, @@ -636,14 +638,15 @@ def _emit_update_statements(base_mapper, uowtransaction, cached_stmt = base_mapper._memo(('update', table), update_stmt) - for (connection, paramkeys, hasvalue, has_all_defaults), \ + for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \ records in groupby( update, lambda rec: ( rec[4], # connection set(rec[2]), # set of parameter keys bool(rec[5]), # whether or not we have "value" parameters - rec[6] # has_all_defaults + rec[6], # has_all_defaults + rec[7] # has all pks ) ): rows = 0 @@ -659,7 +662,9 @@ def _emit_update_statements(base_mapper, uowtransaction, connection.dialect.supports_sane_multi_rowcount allow_multirow = has_all_defaults and not needs_version_id - if bookkeeping and not has_all_defaults and \ + if not has_all_pks: + statement = statement.return_defaults() + elif bookkeeping and not has_all_defaults and \ mapper.base_mapper.eager_defaults: statement = statement.return_defaults() elif mapper.version_id_col is not None: @@ -667,7 +672,8 @@ def _emit_update_statements(base_mapper, uowtransaction, if hasvalue: for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults in records: + connection, value_params, \ + has_all_defaults, has_all_pks in records: c = connection.execute( statement.values(value_params), params) @@ -687,7 +693,8 @@ def _emit_update_statements(base_mapper, uowtransaction, if not allow_multirow: check_rowcount = assert_singlerow for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults in records: + connection, value_params, has_all_defaults, \ + has_all_pks in records: c = cached_connections[connection].\ execute(statement, params) @@ -717,7 +724,8 @@ def _emit_update_statements(base_mapper, uowtransaction, rows += c.rowcount for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults in records: + connection, value_params, \ + has_all_defaults, has_all_pks in records: if bookkeeping: _postfetch( mapper, @@ -1013,7 +1021,9 @@ def _postfetch(mapper, uowtransaction, table, row = result.context.returned_defaults if row is not None: for col in returning_cols: - if col.primary_key: + # pk cols returned from insert are handled + # distinctly, don't step on the values here + if col.primary_key and result.context.isinsert: continue dict_[mapper._columntoproperty[col].key] = row[col] if refresh_flush: diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py index 60387ddce2..6780967c9b 100644 --- a/test/orm/test_naturalpks.py +++ b/test/orm/test_naturalpks.py @@ -122,6 +122,43 @@ class NaturalPKTest(fixtures.MappedTest): assert sess.query(User).get('jack') is None assert sess.query(User).get('ed').fullname == 'jack' + @testing.requires.returning + def test_update_to_sql_expr(self): + users, User = self.tables.users, self.classes.User + + mapper(User, users) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + + sess.add(u1) + sess.flush() + + u1.username = User.username + ' jones' + + sess.flush() + + eq_(u1.username, 'jack jones') + + def test_update_to_self_sql_expr(self): + # SQL expression where the PK won't actually change, + # such as to bump a server side trigger + users, User = self.tables.users, self.classes.User + + mapper(User, users) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + + sess.add(u1) + sess.flush() + + u1.username = User.username + '' + + sess.flush() + + eq_(u1.username, 'jack') + def test_flush_new_pk_after_expire(self): User, users = self.classes.User, self.tables.users diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index 07b090c60d..40b3730970 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -1020,6 +1020,38 @@ class ServerVersioningTest(fixtures.MappedTest): ) self.assert_sql_execution(testing.db, sess.flush, *statements) + def test_sql_expr_bump(self): + sess = self._fixture() + + f1 = self.classes.Foo(value='f1') + sess.add(f1) + sess.flush() + + eq_(f1.version_id, 1) + + f1.id = self.classes.Foo.id + 0 + + sess.flush() + + eq_(f1.version_id, 2) + + @testing.requires.returning + def test_sql_expr_w_mods_bump(self): + sess = self._fixture() + + f1 = self.classes.Foo(id=2, value='f1') + sess.add(f1) + sess.flush() + + eq_(f1.version_id, 1) + + f1.id = self.classes.Foo.id + 3 + + sess.flush() + + eq_(f1.id, 5) + eq_(f1.version_id, 2) + def test_multi_update(self): sess = self._fixture()