]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- add one more #2583 test to cover the "multiple PK switch" use case
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Oct 2012 15:10:42 +0000 (11:10 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Oct 2012 15:10:42 +0000 (11:10 -0400)
lib/sqlalchemy/orm/session.py
test/orm/test_transaction.py

index faa9e5a83c737ad77c9eae11b98d5159e1027004..1df9d45ca7ba0e3da18ed4b90f8039d71344af7c 100644 (file)
@@ -1288,7 +1288,7 @@ class Session(_SessionClassMethods):
                     # map (see test/orm/test_naturalpks.py ReversePKsTest)
                     self.identity_map.discard(state)
                     if state in self.transaction._key_switches:
-                        orig_key = self.transaction._key_switches[0]
+                        orig_key = self.transaction._key_switches[state][0]
                     else:
                         orig_key = state.key
                     self.transaction._key_switches[state] = (orig_key, instance_key)
index 49fdd2864915f8753857a14098575f6623c8fbf7..9c19d5bdf3cc1022303fc5e4b530a197b81b7e68 100644 (file)
@@ -1184,6 +1184,39 @@ class NaturalPKRollbackTest(fixtures.MappedTest):
         assert s.identity_map[(User, ('u1',))] is u1
         assert s.identity_map[(User, ('u2',))] is u2
 
+    def test_multiple_key_replaced_by_update(self):
+        users, User = self.tables.users, self.classes.User
+
+        mapper(User, users)
+
+        u1 = User(name='u1')
+        u2 = User(name='u2')
+        u3 = User(name='u3')
+
+        s = Session()
+        s.add_all([u1, u2, u3])
+        s.commit()
+
+        s.delete(u1)
+        s.delete(u2)
+        s.flush()
+
+        u3.name = 'u1'
+        s.flush()
+
+        u3.name = 'u2'
+        s.flush()
+
+        s.rollback()
+
+        assert u1 in s
+        assert u2 in s
+        assert u3 in s
+
+        assert s.identity_map[(User, ('u1',))] is u1
+        assert s.identity_map[(User, ('u2',))] is u2
+        assert s.identity_map[(User, ('u3',))] is u3
+
     def test_key_replaced_by_oob_insert(self):
         users, User = self.tables.users, self.classes.User