]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Ignore old-style classes when building inheritance graphs. [ticket:1078]
authorJason Kirtland <jek@discorporate.us>
Fri, 15 Aug 2008 22:54:35 +0000 (22:54 +0000)
committerJason Kirtland <jek@discorporate.us>
Fri, 15 Aug 2008 22:54:35 +0000 (22:54 +0000)
lib/sqlalchemy/util.py
test/base/utils.py

index dd104531193b5fe5fceef664f736967de035f8b5..c2d4bf6abf842f57a600d5d8035784d169dfe62a 100644 (file)
@@ -397,15 +397,23 @@ def class_hierarchy(cls):
     class_hierarchy(class A(object)) returns (A, object), not A plus every
     class systemwide that derives from object.
 
+    Old-style classes are discarded and hierarchies rooted on them
+    will not be descended.
+
     """
+    if isinstance(cls, types.ClassType):
+        return list()
     hier = set([cls])
     process = list(cls.__mro__)
     while process:
         c = process.pop()
-        for b in [_ for _ in c.__bases__ if _ not in hier]:
+        if isinstance(c, types.ClassType):
+            continue
+        for b in (_ for _ in c.__bases__
+                  if _ not in hier and not isinstance(_, types.ClassType)):
             process.append(b)
             hier.add(b)
-        if c.__module__ == '__builtin__':
+        if c.__module__ == '__builtin__' or not hasattr(c, '__subclasses__'):
             continue
         for s in [_ for _ in c.__subclasses__() if _ not in hier]:
             process.append(s)
@@ -414,10 +422,10 @@ def class_hierarchy(cls):
 
 def iterate_attributes(cls):
     """iterate all the keys and attributes associated with a class, without using getattr().
-    
+
     Does not use getattr() so that class-sensitive descriptors (i.e. property.__get__())
     are not called.
-    
+
     """
     keys = dir(cls)
     for key in keys:
@@ -425,7 +433,7 @@ def iterate_attributes(cls):
             if key in c.__dict__:
                 yield (key, c.__dict__[key])
                 break
-                
+
 # from paste.deploy.converters
 def asbool(obj):
     if isinstance(obj, (str, unicode)):
@@ -1121,7 +1129,7 @@ class ScopedRegistry(object):
             return object.__new__(_TLocalRegistry)
         else:
             return object.__new__(cls)
-        
+
     def __init__(self, createfunc, scopefunc):
         self.createfunc = createfunc
         self.scopefunc = scopefunc
index 3ce956a16d609bb3af7e9aea8c3f53c422dfc0f6..2c4edc6929b8d95d216097435307c08594288a08 100644 (file)
@@ -884,5 +884,40 @@ class AsInterfaceTest(TestBase):
         obj = {'foo': 123}
         self.assertRaises(TypeError, util.as_interface, obj, cls=self.Something)
 
+
+class TestClassHierarchy(TestBase):
+    def test_object(self):
+        eq_(set(util.class_hierarchy(object)), set((object,)))
+
+    def test_single(self):
+        class A(object):
+            pass
+
+        class B(object):
+            pass
+
+        eq_(set(util.class_hierarchy(A)), set((A, object)))
+        eq_(set(util.class_hierarchy(B)), set((B, object)))
+
+        class C(A, B):
+            pass
+
+        eq_(set(util.class_hierarchy(A)), set((A, B, C, object)))
+        eq_(set(util.class_hierarchy(B)), set((A, B, C, object)))
+
+    def test_oldstyle_mixin(self):
+        class A(object):
+            pass
+
+        class Mixin:
+            pass
+
+        class B(A, Mixin):
+            pass
+
+        eq_(set(util.class_hierarchy(B)), set((A, B, object)))
+        eq_(set(util.class_hierarchy(Mixin)), set())
+        eq_(set(util.class_hierarchy(A)), set((A, B, object)))
+
 if __name__ == "__main__":
     testenv.main()