From: Mike Bayer Date: Wed, 5 Oct 2016 20:55:43 +0000 (-0400) Subject: Consider version_id_prop when emitting bulk UPDATE X-Git-Tag: rel_1_0_16~6 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=e2a976e916575a305363ae8a5841c4058ece90bd;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Consider version_id_prop when emitting bulk UPDATE The version id needs to be part of _changed_dict() so that the value is present to send to _emit_update_statements() Change-Id: Ia85f0ef7714296a75cdc6c88674805afbbe752c8 Fixes: #3781 --- diff --git a/doc/build/changelog/changelog_10.rst b/doc/build/changelog/changelog_10.rst index bf21ce2f75..d3dfbd9120 100644 --- a/doc/build/changelog/changelog_10.rst +++ b/doc/build/changelog/changelog_10.rst @@ -85,6 +85,15 @@ collection of the mapped table, thereby interfering with the initialization of relationships. + .. change:: + :tags: bug, orm + :tickets: 3781 + :versions: 1.1.4 + + Fixed bug in :meth:`.Session.bulk_save` where an UPDATE would + not function correctly in conjunction with a mapping that + implements a version id counter. + .. changelog:: :version: 1.0.15 :released: September 1, 2016 diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index d2922dccb6..7a27e09337 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -82,11 +82,15 @@ def _bulk_update(mapper, mappings, session_transaction, cached_connections = _cached_connection_dict(base_mapper) + search_keys = mapper._primary_key_propkeys + if mapper._version_id_prop: + search_keys = set([mapper._version_id_prop.key]).union(search_keys) + def _changed_dict(mapper, state): return dict( (k, v) for k, v in state.dict.items() if k in state.committed_state or k - in mapper._primary_key_propkeys + in search_keys ) if isstates: diff --git a/test/orm/test_bulk.py b/test/orm/test_bulk.py index 4cc6905be1..fcc16531d6 100644 --- a/test/orm/test_bulk.py +++ b/test/orm/test_bulk.py @@ -13,6 +13,57 @@ class BulkTest(testing.AssertsExecutionResults): run_define_tables = 'each' +class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('version_table', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('version_id', Integer, nullable=False), + Column('value', String(40), nullable=False)) + + @classmethod + def setup_classes(cls): + class Foo(cls.Comparable): + pass + + @classmethod + def setup_mappers(cls): + Foo, version_table = cls.classes.Foo, cls.tables.version_table + + mapper(Foo, version_table, version_id_col=version_table.c.version_id) + + def test_bulk_insert_via_save(self): + Foo = self.classes.Foo + + s = Session() + + s.bulk_save_objects([Foo(value='value')]) + + eq_( + s.query(Foo).all(), + [Foo(version_id=1, value='value')] + ) + + def test_bulk_update_via_save(self): + Foo = self.classes.Foo + + s = Session() + + s.add(Foo(value='value')) + s.commit() + + f1 = s.query(Foo).first() + f1.value = 'new value' + s.bulk_save_objects([f1]) + s.expunge_all() + + eq_( + s.query(Foo).all(), + [Foo(version_id=2, value='new value')] + ) + + class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): @classmethod