]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixed fairly critical bug whereby the same instance could be listed
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 5 Jan 2008 18:26:28 +0000 (18:26 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 5 Jan 2008 18:26:28 +0000 (18:26 +0000)
more than once in the unitofwork.new collection; most typically
reproduced when using a combination of inheriting mappers and
ScopedSession.mapper, as the multiple __init__ calls per instance
could save() the object with distinct _state objects

CHANGES
lib/sqlalchemy/orm/attributes.py
test/orm/attributes.py
test/orm/session.py

diff --git a/CHANGES b/CHANGES
index 4deba3b494cf3123b8333ce43ff622a203448bd0..ce9cc73160affbd6cba89f746c17968b772cf01f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -4,8 +4,14 @@ CHANGES
 0.4.3
 -----
 - orm
-    - Added very rudimentary yielding iterator behavior to Query.  Call
-      query.yield_per(<number of rows>) and evaluate the Query in an
+    - fixed fairly critical bug whereby the same instance could be listed
+      more than once in the unitofwork.new collection; most typically
+      reproduced when using a combination of inheriting mappers and 
+      ScopedSession.mapper, as the multiple __init__ calls per instance
+      could save() the object with distinct _state objects
+      
+    - added very rudimentary yielding iterator behavior to Query.  Call
+      query.yield_per(<number of rows>) and evaluate the Query in an 
       iterative context; every collection of N rows will be packaged up
       and yielded.  Use this method with extreme caution since it does
       not attempt to reconcile eagerly loaded collections across
index 21b8a4e641933a4efdd048a16e53b9e148cbd6c0..91bf130344788df6efe1951c0b8700ed520d5f9f 100644 (file)
@@ -1122,7 +1122,8 @@ def register_class(class_, extra_init=None, on_exception=None, deferred_scalar_l
     doinit = False
 
     def init(instance, *args, **kwargs):
-        instance._state = InstanceState(instance)
+        if not hasattr(instance, '_state'):
+            instance._state = InstanceState(instance)
 
         if extra_init:
             extra_init(class_, oldinit, instance, args, kwargs)
index dd15e41e5694d55493f753f9ff3c5fdf047d1db3..a03123c9d86bfbdea42bd169821d2c88ab786d77 100644 (file)
@@ -249,6 +249,25 @@ class AttributesTest(PersistTest):
         assert x.element2 == 'this is the shared attr'
         assert y.element2 == 'this is the shared attr'
 
+    def test_no_double_state(self):
+        states = set()
+        class Foo(object):
+            def __init__(self):
+                states.add(self._state)
+        class Bar(Foo):
+            def __init__(self):
+                states.add(self._state)
+                Foo.__init__(self)
+        
+        
+        attributes.register_class(Foo)
+        attributes.register_class(Bar)
+        
+        b = Bar()
+        self.assertEquals(len(states), 1)
+        self.assertEquals(list(states)[0].obj(), b)
+        
+
     def test_inheritance2(self):
         """test that the attribute manager can properly traverse the managed attributes of an object,
         if the object is of a descendant class with managed attributes in the parent class"""
index 8baf752754b9188b033046363f761f58238aeac9..8cd633745a9ce69f7a3fd54989b9628bb74dbfe8 100644 (file)
@@ -798,6 +798,23 @@ class SessionTest(AssertMixin):
         u3 = sess.query(User).get(u1.user_id)
         assert u3 is not u1 and u3 is not u2 and u3.user_name == u1.user_name
 
+    def test_no_double_save(self):
+        sess = create_session()
+        class Foo(object):
+            def __init__(self):
+                sess.save(self)
+        class Bar(Foo):
+            def __init__(self):
+                sess.save(self)
+                Foo.__init__(self)
+        mapper(Foo, users)
+        mapper(Bar, users)
+
+        b = Bar()
+        assert b in sess
+        assert len(list(sess)) == 1
+        
+        
 class ScopedSessionTest(ORMTest):
 
     def define_tables(self, metadata):
@@ -894,7 +911,7 @@ class ScopedMapperTest(PersistTest):
             pass
         Session.mapper(Baz, table2, extension=ext)
         assert hasattr(Baz, 'query')
-    
+
     def test_validating_constructor(self):
         s2 = SomeObject(someid=12)
         s3 = SomeOtherObject(someid=123, bogus=345)