]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fixed attribute manager's ability to traverse the full set of managed attributes...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Jun 2006 19:55:48 +0000 (19:55 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Jun 2006 19:55:48 +0000 (19:55 +0000)
lib/sqlalchemy/attributes.py
test/base/attributes.py
test/orm/inheritance.py

index b7ad5249b0c326d54118a5c70192fad7efea4962..2bf3363988dc28ec4fe923e9c6604d52f469317a 100644 (file)
@@ -519,7 +519,7 @@ class AttributeHistory(object):
                 else:
                     self._deleted_items = []
                 self._unchanged_items = []
-        #print "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items
+        #print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items
     def __iter__(self):
         return iter(self._current)
     def added_items(self):
@@ -566,7 +566,8 @@ class AttributeManager(object):
         """returns an iterator of all InstrumentedAttribute objects associated with the given class."""
         if not isinstance(class_, type):
             raise repr(class_) + " is not a type"
-        for value in class_.__dict__.values():
+        for key in dir(class_):
+            value = getattr(class_, key)
             if isinstance(value, InstrumentedAttribute):
                 yield value
                 
index 4b8bfd39ae774a4dd9858cc7c6e77bd1c27460bd..19eedd0f693d3d863c414a3c6e9fb8109101098a 100644 (file)
@@ -183,6 +183,22 @@ class AttributesTest(PersistTest):
         assert x.element2 == 'this is the shared attr'
         assert y.element2 == 'this is the shared attr'
 
+    def testinheritance2(self):
+        """test that the attribute manager can properly traverse the managed attributes of an object,
+        if the object is of a descendant class with managed attributes in the parent class"""
+        class Foo(object):pass
+        class Bar(Foo):pass
+        manager = attributes.AttributeManager()
+        manager.register_attribute(Foo, 'element', uselist=False)
+        x = Bar()
+        x.element = 'this is the element'
+        hist = manager.get_history(x, 'element')
+        assert hist.added_items() == ['this is the element']
+        manager.commit(x)
+        hist = manager.get_history(x, 'element')
+        assert hist.added_items() == []
+        assert hist.unchanged_items() == ['this is the element']
+
     def testlazyhistory(self):
         """tests that history functions work with lazy-loading attributes"""
         class Foo(object):pass
index 842a63a26662c622a90751586ad8f35569e8442e..bca0ffde320affbe36c96e98680eff362651a6d7 100644 (file)
@@ -442,8 +442,11 @@ class InheritTest7(testbase.AssertMixin):
         metadata.create_all()
     def tearDownAll(self):
         metadata.drop_all()
-
-    def testbasic(self):
+    def tearDown(self):
+        for t in metadata.table_iterator(reverse=True):
+            t.delete().execute()
+            
+    def testone(self):
         class User(object):pass
         class Role(object):pass
         class Admin(User):pass
@@ -469,6 +472,43 @@ class InheritTest7(testbase.AssertMixin):
         sess.flush()
         
         assert user_roles.count().scalar() == 1
+
+    def testtwo(self):
+        class User(object):
+            def __init__(self, email=None, password=None):
+                self.email = email
+                self.password = password
+
+        class Role(object):
+            def __init__(self, description=None):
+                self.description = description
+
+        class Admin(User):pass
+
+        role_mapper = mapper(Role, roles)
+        user_mapper = mapper(User, users, properties = {
+                'roles' : relation(Role, secondary=user_roles, lazy=False, private=False)
+            }
+        )
+
+        admin_mapper = mapper(Admin, admins, inherits=user_mapper) 
+
+        # create roles
+        adminrole = Role('admin')
+
+        sess = create_session()
+        sess.save(adminrole)
+        sess.flush()
+
+        # create admin user
+        a = Admin(email='tim', password='admin')
+        a.roles.append(adminrole)
+        sess.save(a)
+        sess.flush()
+
+        a.password = 'sadmin'
+        sess.flush()
+        assert user_roles.count().scalar() == 1
         
 if __name__ == "__main__":    
     testbase.main()