From: Mike Bayer Date: Mon, 14 Dec 2015 22:24:47 +0000 (-0500) Subject: - Fixed bug in :meth:`.Update.return_defaults` which would cause all X-Git-Tag: rel_1_1_0b1~84^2~77^2~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0e4c4d7efc08d04c3c0ae960428b08ada37e4a91;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Fixed bug in :meth:`.Update.return_defaults` which would cause all insert-default holding columns not otherwise included in the SET clause (such as primary key cols) to get rendered into the RETURNING even though this is an UPDATE. - Major fixes to the :paramref:`.Mapper.eager_defaults` flag, this flag would not be honored correctly in the case that multiple UPDATE statements were to be emitted, either as part of a flush or a bulk update operation. Additionally, RETURNING would be emitted unnecessarily within update statements. fixes #3609 --- diff --git a/doc/build/changelog/changelog_10.rst b/doc/build/changelog/changelog_10.rst index 950046cd0e..974aa5f1a1 100644 --- a/doc/build/changelog/changelog_10.rst +++ b/doc/build/changelog/changelog_10.rst @@ -18,6 +18,27 @@ .. changelog:: :version: 1.0.11 + .. change:: + :tags: bug, sql + :tickets: 3609 + :versions: 1.1.0b1 + + Fixed bug in :meth:`.Update.return_defaults` which would cause all + insert-default holding columns not otherwise included in the SET + clause (such as primary key cols) to get rendered into the RETURNING + even though this is an UPDATE. + + .. change:: + :tags: bug, orm + :tickets: 3609 + :versions: 1.1.0b1 + + Major fixes to the :paramref:`.Mapper.eager_defaults` flag, this + flag would not be honored correctly in the case that multiple + UPDATE statements were to be emitted, either as part of a flush + or a bulk update operation. Additionally, RETURNING + would be emitted unnecessarily within update statements. + .. change:: :tags: bug, orm :tickets: 3606 diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 5ade4b9665..95aa14a26f 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1970,12 +1970,24 @@ class Mapper(InspectionAttr): ( table, frozenset([ - col for col in columns + col.key for col in columns if col.server_default is not None]) ) for table, columns in self._cols_by_table.items() ) + @_memoized_configured_property + def _server_onupdate_default_cols(self): + return dict( + ( + table, + frozenset([ + col.key for col in columns + if col.server_onupdate is not None]) + ) + for table, columns in self._cols_by_table.items() + ) + @property def selectable(self): """The :func:`.select` construct this :class:`.Mapper` selects from diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 768c1146a3..88c96e94ce 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -448,6 +448,7 @@ def _collect_update_commands( set(propkey_to_col).intersection(state_dict).difference( mapper._pk_keys_by_table[table]) ) + has_all_defaults = True else: params = {} for propkey in set(propkey_to_col).intersection( @@ -463,6 +464,12 @@ def _collect_update_commands( value, state.committed_state[propkey]) is not True: params[col.key] = value + if mapper.base_mapper.eager_defaults: + has_all_defaults = mapper._server_onupdate_default_cols[table].\ + issubset(params) + else: + has_all_defaults = True + if update_version_id is not None and \ mapper.version_id_col in mapper._cols_by_table[table]: @@ -529,7 +536,7 @@ def _collect_update_commands( params.update(pk_params) yield ( state, state_dict, params, mapper, - connection, value_params) + connection, value_params, has_all_defaults) def _collect_post_update_commands(base_mapper, uowtransaction, table, @@ -619,23 +626,20 @@ def _emit_update_statements(base_mapper, uowtransaction, type_=mapper.version_id_col.type)) stmt = table.update(clause) - if mapper.base_mapper.eager_defaults: - stmt = stmt.return_defaults() - elif mapper.version_id_col is not None: - stmt = stmt.return_defaults(mapper.version_id_col) - return stmt statement = base_mapper._memo(('update', table), update_stmt) - for (connection, paramkeys, hasvalue), \ + for (connection, paramkeys, hasvalue, has_all_defaults), \ records in groupby( update, lambda rec: ( rec[4], # connection set(rec[2]), # set of parameter keys - bool(rec[5]))): # whether or not we have "value" parameters - + bool(rec[5]), # whether or not we have "value" parameters + rec[6] # has_all_defaults + ) + ): rows = 0 records = list(records) @@ -645,11 +649,16 @@ def _emit_update_statements(base_mapper, uowtransaction, assert_singlerow = connection.dialect.supports_sane_rowcount assert_multirow = assert_singlerow and \ connection.dialect.supports_sane_multi_rowcount - allow_multirow = not needs_version_id + allow_multirow = has_all_defaults and not needs_version_id + + if bookkeeping and mapper.base_mapper.eager_defaults: + statement = statement.return_defaults() + elif mapper.version_id_col is not None: + statement = statement.return_defaults(mapper.version_id_col) if hasvalue: for state, state_dict, params, mapper, \ - connection, value_params in records: + connection, value_params, has_all_defaults in records: c = connection.execute( statement.values(value_params), params) @@ -669,7 +678,7 @@ def _emit_update_statements(base_mapper, uowtransaction, if not allow_multirow: check_rowcount = assert_singlerow for state, state_dict, params, mapper, \ - connection, value_params in records: + connection, value_params, has_all_defaults in records: c = cached_connections[connection].\ execute(statement, params) @@ -699,7 +708,7 @@ def _emit_update_statements(base_mapper, uowtransaction, rows += c.rowcount for state, state_dict, params, mapper, \ - connection, value_params in records: + connection, value_params, has_all_defaults in records: if bookkeeping: _postfetch( mapper, @@ -741,6 +750,7 @@ def _emit_insert_statements(base_mapper, uowtransaction, bool(rec[5]), # whether we have "value" parameters rec[6], rec[7])): + if not bookkeeping or \ ( has_all_defaults diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 18b96018d4..c5495ccde5 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -493,6 +493,7 @@ def _append_param_update( else: compiler.postfetch.append(c) elif implicit_return_defaults and \ + stmt._return_defaults is not True and \ c in implicit_return_defaults: compiler.returning.append(c) diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 2434936071..39d0789855 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -13,6 +13,7 @@ import contextlib from .. import event from sqlalchemy.schema import _DDLCompiles from sqlalchemy.engine.util import _distill_params +from sqlalchemy.engine import url class AssertRule(object): @@ -58,16 +59,25 @@ class CursorSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): - def __init__(self, statement, params=None): + def __init__(self, statement, params=None, dialect='default'): self.statement = statement self.params = params + self.dialect = dialect def _compare_sql(self, execute_observed, received_statement): stmt = re.sub(r'[\n\t]', '', self.statement) return received_statement == stmt def _compile_dialect(self, execute_observed): - return DefaultDialect() + if self.dialect == 'default': + return DefaultDialect() + else: + # ugh + if self.dialect == 'postgresql': + params = {'implicit_returning': True} + else: + params = {} + return url.URL(self.dialect).get_dialect()(**params) def _received_statement(self, execute_observed): """reconstruct the statement and params in terms @@ -159,7 +169,7 @@ class CompiledSQL(SQLMatchRule): 'Testing for compiled statement %r partial params %r, ' 'received %%(received_statement)r with params ' '%%(received_parameters)r' % ( - self.statement, expected_params + self.statement.replace('%', '%%'), expected_params ) ) @@ -170,6 +180,7 @@ class RegexSQL(CompiledSQL): self.regex = re.compile(regex) self.orig_regex = regex self.params = params + self.dialect = 'default' def _failure_message(self, expected_params): return ( diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 09240dfdb6..c8ce13c913 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -5,7 +5,8 @@ from sqlalchemy.testing.schema import Table, Column from test.orm import _fixtures from sqlalchemy import exc, util from sqlalchemy.testing import fixtures, config -from sqlalchemy import Integer, String, ForeignKey, func, literal +from sqlalchemy import Integer, String, ForeignKey, func, \ + literal, FetchedValue, text from sqlalchemy.orm import mapper, relationship, backref, \ create_session, unitofwork, attributes,\ Session, exc as orm_exc @@ -1848,6 +1849,450 @@ class NoAttrEventInFlushTest(fixtures.MappedTest): eq_(t1.returning_val, 5) +class EagerDefaultsTest(fixtures.MappedTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + 'test', metadata, + Column('id', Integer, primary_key=True), + Column('foo', Integer, server_default="3") + ) + + Table( + 'test2', metadata, + Column('id', Integer, primary_key=True), + Column('foo', Integer), + Column('bar', Integer, server_onupdate=FetchedValue()) + ) + + @classmethod + def setup_classes(cls): + class Thing(cls.Basic): + pass + + class Thing2(cls.Basic): + pass + + @classmethod + def setup_mappers(cls): + Thing = cls.classes.Thing + + mapper(Thing, cls.tables.test, eager_defaults=True) + + Thing2 = cls.classes.Thing2 + + mapper(Thing2, cls.tables.test2, eager_defaults=True) + + def test_insert_defaults_present(self): + Thing = self.classes.Thing + s = Session() + + t1, t2 = ( + Thing(id=1, foo=5), + Thing(id=2, foo=10) + ) + + s.add_all([t1, t2]) + + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (:id, :foo)", + [{'foo': 5, 'id': 1}, {'foo': 10, 'id': 2}] + ), + ) + + def go(): + eq_(t1.foo, 5) + eq_(t2.foo, 10) + + self.assert_sql_count(testing.db, go, 0) + + def test_insert_defaults_present_as_expr(self): + Thing = self.classes.Thing + s = Session() + + t1, t2 = ( + Thing(id=1, foo=text("2 + 5")), + Thing(id=2, foo=text("5 + 5")) + ) + + s.add_all([t1, t2]) + + if testing.db.dialect.implicit_returning: + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (%(id)s, 2 + 5) " + "RETURNING test.foo", + [{'id': 1}], + dialect='postgresql' + ), + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (%(id)s, 5 + 5) " + "RETURNING test.foo", + [{'id': 2}], + dialect='postgresql' + ) + ) + + else: + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (:id, 2 + 5)", + [{'id': 1}] + ), + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (:id, 5 + 5)", + [{'id': 2}] + ), + CompiledSQL( + "SELECT test.foo AS test_foo FROM test " + "WHERE test.id = :param_1", + [{'param_1': 1}] + ), + CompiledSQL( + "SELECT test.foo AS test_foo FROM test " + "WHERE test.id = :param_1", + [{'param_1': 2}] + ), + ) + + def go(): + eq_(t1.foo, 7) + eq_(t2.foo, 10) + + self.assert_sql_count(testing.db, go, 0) + + def test_insert_defaults_nonpresent(self): + Thing = self.classes.Thing + s = Session() + + t1, t2 = ( + Thing(id=1), + Thing(id=2) + ) + + s.add_all([t1, t2]) + + if testing.db.dialect.implicit_returning: + self.assert_sql_execution( + testing.db, + s.commit, + CompiledSQL( + "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo", + [{'id': 1}], + dialect='postgresql' + ), + CompiledSQL( + "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo", + [{'id': 2}], + dialect='postgresql' + ), + ) + else: + self.assert_sql_execution( + testing.db, + s.commit, + CompiledSQL( + "INSERT INTO test (id) VALUES (:id)", + [{'id': 1}, {'id': 2}] + ), + CompiledSQL( + "SELECT test.foo AS test_foo FROM test " + "WHERE test.id = :param_1", + [{'param_1': 1}] + ), + CompiledSQL( + "SELECT test.foo AS test_foo FROM test " + "WHERE test.id = :param_1", + [{'param_1': 2}] + ) + ) + + def test_update_defaults_nonpresent(self): + Thing2 = self.classes.Thing2 + s = Session() + + t1, t2, t3, t4 = ( + Thing2(id=1, foo=1, bar=2), + Thing2(id=2, foo=2, bar=3), + Thing2(id=3, foo=3, bar=4), + Thing2(id=4, foo=4, bar=5) + ) + + s.add_all([t1, t2, t3, t4]) + s.flush() + + t1.foo = 5 + t2.foo = 6 + t2.bar = 10 + t3.foo = 7 + t4.foo = 8 + t4.bar = 12 + + if testing.db.dialect.implicit_returning: + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s " + "WHERE test2.id = %(test2_id)s " + "RETURNING test2.bar", + [{'foo': 5, 'test2_id': 1}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s " + "WHERE test2.id = %(test2_id)s", + [{'foo': 6, 'bar': 10, 'test2_id': 2}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s " + "WHERE test2.id = %(test2_id)s " + "RETURNING test2.bar", + [{'foo': 7, 'test2_id': 3}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s " + "WHERE test2.id = %(test2_id)s", + [{'foo': 8, 'bar': 12, 'test2_id': 4}], + dialect='postgresql' + ), + ) + else: + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", + [{'foo': 5, 'test2_id': 1}] + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo, bar=:bar " + "WHERE test2.id = :test2_id", + [{'foo': 6, 'bar': 10, 'test2_id': 2}], + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", + [{'foo': 7, 'test2_id': 3}] + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo, bar=:bar " + "WHERE test2.id = :test2_id", + [{'foo': 8, 'bar': 12, 'test2_id': 4}], + ), + CompiledSQL( + "SELECT test2.bar AS test2_bar FROM test2 " + "WHERE test2.id = :param_1", + [{'param_1': 1}] + ), + CompiledSQL( + "SELECT test2.bar AS test2_bar FROM test2 " + "WHERE test2.id = :param_1", + [{'param_1': 3}] + ) + ) + + def go(): + eq_(t1.bar, 2) + eq_(t2.bar, 10) + eq_(t3.bar, 4) + eq_(t4.bar, 12) + + self.assert_sql_count(testing.db, go, 0) + + def test_update_defaults_present_as_expr(self): + Thing2 = self.classes.Thing2 + s = Session() + + t1, t2, t3, t4 = ( + Thing2(id=1, foo=1, bar=2), + Thing2(id=2, foo=2, bar=3), + Thing2(id=3, foo=3, bar=4), + Thing2(id=4, foo=4, bar=5) + ) + + s.add_all([t1, t2, t3, t4]) + s.flush() + + t1.foo = 5 + t1.bar = text("1 + 1") + t2.foo = 6 + t2.bar = 10 + t3.foo = 7 + t4.foo = 8 + t4.bar = text("5 + 7") + + if testing.db.dialect.implicit_returning: + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s, bar=1 + 1 " + "WHERE test2.id = %(test2_id)s " + "RETURNING test2.bar", + [{'foo': 5, 'test2_id': 1}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s " + "WHERE test2.id = %(test2_id)s", + [{'foo': 6, 'bar': 10, 'test2_id': 2}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s " + "WHERE test2.id = %(test2_id)s " + "RETURNING test2.bar", + [{'foo': 7, 'test2_id': 3}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s, bar=5 + 7 " + "WHERE test2.id = %(test2_id)s RETURNING test2.bar", + [{'foo': 8, 'test2_id': 4}], + dialect='postgresql' + ), + ) + else: + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "UPDATE test2 SET foo=:foo, bar=1 + 1 " + "WHERE test2.id = :test2_id", + [{'foo': 5, 'test2_id': 1}] + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo, bar=:bar " + "WHERE test2.id = :test2_id", + [{'foo': 6, 'bar': 10, 'test2_id': 2}], + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", + [{'foo': 7, 'test2_id': 3}] + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo, bar=5 + 7 " + "WHERE test2.id = :test2_id", + [{'foo': 8, 'test2_id': 4}], + ), + CompiledSQL( + "SELECT test2.bar AS test2_bar FROM test2 " + "WHERE test2.id = :param_1", + [{'param_1': 1}] + ), + CompiledSQL( + "SELECT test2.bar AS test2_bar FROM test2 " + "WHERE test2.id = :param_1", + [{'param_1': 3}] + ), + CompiledSQL( + "SELECT test2.bar AS test2_bar FROM test2 " + "WHERE test2.id = :param_1", + [{'param_1': 4}] + ) + ) + + def go(): + eq_(t1.bar, 2) + eq_(t2.bar, 10) + eq_(t3.bar, 4) + eq_(t4.bar, 12) + + self.assert_sql_count(testing.db, go, 0) + + def test_insert_defaults_bulk_insert(self): + Thing = self.classes.Thing + s = Session() + + mappings = [ + {"id": 1}, + {"id": 2} + ] + + self.assert_sql_execution( + testing.db, + lambda: s.bulk_insert_mappings(Thing, mappings), + CompiledSQL( + "INSERT INTO test (id) VALUES (:id)", + [{'id': 1}, {'id': 2}] + ) + ) + + def test_update_defaults_bulk_update(self): + Thing2 = self.classes.Thing2 + s = Session() + + t1, t2, t3, t4 = ( + Thing2(id=1, foo=1, bar=2), + Thing2(id=2, foo=2, bar=3), + Thing2(id=3, foo=3, bar=4), + Thing2(id=4, foo=4, bar=5) + ) + + s.add_all([t1, t2, t3, t4]) + s.flush() + + mappings = [ + {"id": 1, "foo": 5}, + {"id": 2, "foo": 6, "bar": 10}, + {"id": 3, "foo": 7}, + {"id": 4, "foo": 8} + ] + + self.assert_sql_execution( + testing.db, + lambda: s.bulk_update_mappings(Thing2, mappings), + CompiledSQL( + "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", + [{'foo': 5, 'test2_id': 1}] + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo, bar=:bar " + "WHERE test2.id = :test2_id", + [{'foo': 6, 'bar': 10, 'test2_id': 2}] + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", + [{'foo': 7, 'test2_id': 3}, {'foo': 8, 'test2_id': 4}] + ) + ) + + def test_update_defaults_present(self): + Thing2 = self.classes.Thing2 + s = Session() + + t1, t2 = ( + Thing2(id=1, foo=1, bar=2), + Thing2(id=2, foo=2, bar=3) + ) + + s.add_all([t1, t2]) + s.flush() + + t1.bar = 5 + t2.bar = 10 + + self.assert_sql_execution( + testing.db, + s.commit, + CompiledSQL( + "UPDATE test2 SET bar=%(bar)s WHERE test2.id = %(test2_id)s", + [{'bar': 5, 'test2_id': 1}, {'bar': 10, 'test2_id': 2}], + dialect='postgresql' + ) + ) + class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults): """test support for custom datatypes that return a non-__bool__ value when compared via __eq__(), eg. ticket 3469""" diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index f42069230f..124053d470 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -894,19 +894,26 @@ class ServerVersioningTest(fixtures.MappedTest): class Bar(cls.Basic): pass - def _fixture(self, expire_on_commit=True): + def _fixture(self, expire_on_commit=True, eager_defaults=False): Foo, version_table = self.classes.Foo, self.tables.version_table mapper( Foo, version_table, version_id_col=version_table.c.version_id, version_id_generator=False, + eager_defaults=eager_defaults ) s1 = Session(expire_on_commit=expire_on_commit) return s1 def test_insert_col(self): - sess = self._fixture() + self._test_insert_col() + + def test_insert_col_eager_defaults(self): + self._test_insert_col(eager_defaults=True) + + def _test_insert_col(self, **kw): + sess = self._fixture(**kw) f1 = self.classes.Foo(value='f1') sess.add(f1) @@ -935,7 +942,13 @@ class ServerVersioningTest(fixtures.MappedTest): self.assert_sql_execution(testing.db, sess.flush, *statements) def test_update_col(self): - sess = self._fixture() + self._test_update_col() + + def test_update_col_eager_defaults(self): + self._test_update_col(eager_defaults=True) + + def _test_update_col(self, **kw): + sess = self._fixture(**kw) f1 = self.classes.Foo(value='f1') sess.add(f1) diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index cd9f632b9b..77a0c60075 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -387,6 +387,31 @@ class ReturnDefaultsTest(fixtures.TablesTest): {"data": None, 'upddef': 1} ) + def test_insert_all(self): + t1 = self.tables.t1 + result = testing.db.execute( + t1.insert().values(upddef=1).return_defaults() + ) + eq_( + dict(result.returned_defaults), + {"id": 1, "data": None, "insdef": 0} + ) + + def test_update_all(self): + t1 = self.tables.t1 + testing.db.execute( + t1.insert().values(upddef=1) + ) + result = testing.db.execute( + t1.update(). + values(insdef=2).return_defaults() + ) + eq_( + dict(result.returned_defaults), + {'upddef': 1} + ) + + class ImplicitReturningFlag(fixtures.TestBase): __backend__ = True