From 49090a7a9434245b03a7f867add2401d3a78fead Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 26 Jun 2006 19:55:48 +0000 Subject: [PATCH] fixed attribute manager's ability to traverse the full set of managed attributes for a descendant class, + 2 unit tests --- lib/sqlalchemy/attributes.py | 5 ++-- test/base/attributes.py | 16 +++++++++++++ test/orm/inheritance.py | 44 ++++++++++++++++++++++++++++++++++-- 3 files changed, 61 insertions(+), 4 deletions(-) diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py index b7ad5249b0..2bf3363988 100644 --- a/lib/sqlalchemy/attributes.py +++ b/lib/sqlalchemy/attributes.py @@ -519,7 +519,7 @@ class AttributeHistory(object): else: self._deleted_items = [] self._unchanged_items = [] - #print "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items + #print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items def __iter__(self): return iter(self._current) def added_items(self): @@ -566,7 +566,8 @@ class AttributeManager(object): """returns an iterator of all InstrumentedAttribute objects associated with the given class.""" if not isinstance(class_, type): raise repr(class_) + " is not a type" - for value in class_.__dict__.values(): + for key in dir(class_): + value = getattr(class_, key) if isinstance(value, InstrumentedAttribute): yield value diff --git a/test/base/attributes.py b/test/base/attributes.py index 4b8bfd39ae..19eedd0f69 100644 --- a/test/base/attributes.py +++ b/test/base/attributes.py @@ -183,6 +183,22 @@ class AttributesTest(PersistTest): assert x.element2 == 'this is the shared attr' assert y.element2 == 'this is the shared attr' + def testinheritance2(self): + """test that the attribute manager can properly traverse the managed attributes of an object, + if the object is of a descendant class with managed attributes in the parent class""" + class Foo(object):pass + class Bar(Foo):pass + manager = attributes.AttributeManager() + manager.register_attribute(Foo, 'element', uselist=False) + x = Bar() + x.element = 'this is the element' + hist = manager.get_history(x, 'element') + assert hist.added_items() == ['this is the element'] + manager.commit(x) + hist = manager.get_history(x, 'element') + assert hist.added_items() == [] + assert hist.unchanged_items() == ['this is the element'] + def testlazyhistory(self): """tests that history functions work with lazy-loading attributes""" class Foo(object):pass diff --git a/test/orm/inheritance.py b/test/orm/inheritance.py index 842a63a266..bca0ffde32 100644 --- a/test/orm/inheritance.py +++ b/test/orm/inheritance.py @@ -442,8 +442,11 @@ class InheritTest7(testbase.AssertMixin): metadata.create_all() def tearDownAll(self): metadata.drop_all() - - def testbasic(self): + def tearDown(self): + for t in metadata.table_iterator(reverse=True): + t.delete().execute() + + def testone(self): class User(object):pass class Role(object):pass class Admin(User):pass @@ -469,6 +472,43 @@ class InheritTest7(testbase.AssertMixin): sess.flush() assert user_roles.count().scalar() == 1 + + def testtwo(self): + class User(object): + def __init__(self, email=None, password=None): + self.email = email + self.password = password + + class Role(object): + def __init__(self, description=None): + self.description = description + + class Admin(User):pass + + role_mapper = mapper(Role, roles) + user_mapper = mapper(User, users, properties = { + 'roles' : relation(Role, secondary=user_roles, lazy=False, private=False) + } + ) + + admin_mapper = mapper(Admin, admins, inherits=user_mapper) + + # create roles + adminrole = Role('admin') + + sess = create_session() + sess.save(adminrole) + sess.flush() + + # create admin user + a = Admin(email='tim', password='admin') + a.roles.append(adminrole) + sess.save(a) + sess.flush() + + a.password = 'sadmin' + sess.flush() + assert user_roles.count().scalar() == 1 if __name__ == "__main__": testbase.main() -- 2.47.3