]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Consider version_id_prop when emitting bulk UPDATE
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 5 Oct 2016 20:55:43 +0000 (16:55 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 10 Nov 2016 21:30:11 +0000 (16:30 -0500)
The version id needs to be part of _changed_dict()
so that the value is present to send to
_emit_update_statements()

Change-Id: Ia85f0ef7714296a75cdc6c88674805afbbe752c8
Fixes: #3781
doc/build/changelog/changelog_10.rst
lib/sqlalchemy/orm/persistence.py
test/orm/test_bulk.py

index 4b28e40a71df4127af1094c6dc1050b2335f2240..bd419dfd719b76867d0ad71c38d5930b19afd632 100644 (file)
         collection of the mapped table, thereby interfering with the
         initialization of relationships.
 
+    .. change::
+        :tags: bug, orm
+        :tickets: 3781
+        :versions: 1.1.4
+
+        Fixed bug in :meth:`.Session.bulk_save` where an UPDATE would
+        not function correctly in conjunction with a mapping that
+        implements a version id counter.
+
 .. changelog::
     :version: 1.0.15
     :released: September 1, 2016
index bf51a2a833003198e297a03038f284f7412183ac..2f7acba3aeccd53f9a15fa96e669fdae7c64befa 100644 (file)
@@ -84,11 +84,16 @@ def _bulk_update(mapper, mappings, session_transaction,
 
     cached_connections = _cached_connection_dict(base_mapper)
 
+    search_keys = mapper._primary_key_propkeys
+    if mapper._version_id_prop:
+        search_keys = set([mapper._version_id_prop.key]).union(search_keys)
+
     def _changed_dict(mapper, state):
         return dict(
             (k, v)
             for k, v in state.dict.items() if k in state.committed_state or k
-            in mapper._primary_key_propkeys
+            in search_keys
+
         )
 
     if isstates:
index 8a8fd004de4c93d0e467f3c139aebc96e23afc95..cd569fa73c13e1f4b52dda9538907c0201361d31 100644 (file)
@@ -13,6 +13,57 @@ class BulkTest(testing.AssertsExecutionResults):
     run_define_tables = 'each'
 
 
+class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('version_table', metadata,
+              Column('id', Integer, primary_key=True,
+                     test_needs_autoincrement=True),
+              Column('version_id', Integer, nullable=False),
+              Column('value', String(40), nullable=False))
+
+    @classmethod
+    def setup_classes(cls):
+        class Foo(cls.Comparable):
+            pass
+
+    @classmethod
+    def setup_mappers(cls):
+        Foo, version_table = cls.classes.Foo, cls.tables.version_table
+
+        mapper(Foo, version_table, version_id_col=version_table.c.version_id)
+
+    def test_bulk_insert_via_save(self):
+        Foo = self.classes.Foo
+
+        s = Session()
+
+        s.bulk_save_objects([Foo(value='value')])
+
+        eq_(
+            s.query(Foo).all(),
+            [Foo(version_id=1, value='value')]
+        )
+
+    def test_bulk_update_via_save(self):
+        Foo = self.classes.Foo
+
+        s = Session()
+
+        s.add(Foo(value='value'))
+        s.commit()
+
+        f1 = s.query(Foo).first()
+        f1.value = 'new value'
+        s.bulk_save_objects([f1])
+        s.expunge_all()
+
+        eq_(
+            s.query(Foo).all(),
+            [Foo(version_id=2, value='new value')]
+        )
+
+
 class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
 
     @classmethod