]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Corrected behavior of get_cls_kwargs and friends
authorJason Kirtland <jek@discorporate.us>
Thu, 24 Jan 2008 00:08:40 +0000 (00:08 +0000)
committerJason Kirtland <jek@discorporate.us>
Thu, 24 Jan 2008 00:08:40 +0000 (00:08 +0000)
lib/sqlalchemy/util.py
test/base/utils.py

index 4f30f76ba50fe881011d2b92275d15aaba7fa646..bdcaf37f07b401c5b087b043fbf646619da79e86 100644 (file)
@@ -4,8 +4,9 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import itertools, sys, warnings, sets, weakref
+import inspect, itertools, sets, sys, warnings, weakref
 import __builtin__
+types = __import__('types')
 
 from sqlalchemy import exceptions
 
@@ -181,20 +182,37 @@ class ArgSingleton(type):
             return instance
 
 def get_cls_kwargs(cls):
-    """Return the full set of legal kwargs for the given `cls`."""
+    """Return the full set of inherited kwargs for the given `cls`.
+
+    Probes a class's __init__ method, collecting all named arguments.  If the
+    __init__ defines a **kwargs catch-all, then the constructor is presumed to
+    pass along unrecognized keywords to it's base classes, and the collection
+    process is repeated recursively on each of the bases.
+    """
 
-    kw = []
     for c in cls.__mro__:
-        cons = c.__init__
-        if hasattr(cons, 'func_code'):
-            for vn in cons.func_code.co_varnames:
-                if vn != 'self':
-                    kw.append(vn)
-    return kw
+        if '__init__' in c.__dict__:
+            stack = [c]
+            break
+    else:
+        return []
+
+    args = Set()
+    while stack:
+        class_ = stack.pop()
+        ctr = class_.__dict__.get('__init__', False)
+        if not ctr or not isinstance(ctr, types.FunctionType):
+            continue
+        names, _, has_kw, _ = inspect.getargspec(ctr)
+        args |= Set(names)
+        if has_kw:
+            stack.extend(class_.__bases__)
+    args.discard('self')
+    return list(args)
 
 def get_func_kwargs(func):
     """Return the full set of legal kwargs for the given `func`."""
-    return [vn for vn in func.func_code.co_varnames]
+    return inspect.getargspec(func)[0]
 
 # from paste.deploy.converters
 def asbool(obj):
index 5a034e0b0f5d235d86945345c984fcb3f22fa1bb..837eb058f0623a801b7184297d5865741d101d7a 100644 (file)
@@ -305,5 +305,72 @@ class DictlikeIteritemsTest(unittest.TestCase):
         self._notok(duck6())
 
 
+class ArgInspectionTest(PersistTest):
+    def test_get_cls_kwargs(self):
+        class A(object):
+            def __init__(self, a):
+                pass
+        class A1(A):
+            def __init__(self, a1):
+                pass
+        class A11(A1):
+            def __init__(self, a11, **kw):
+                pass
+        class B(object):
+            def __init__(self, b, **kw):
+                pass
+        class B1(B):
+            def __init__(self, b1, **kw):
+                pass
+        class AB(A, B):
+            def __init__(self, ab):
+                pass
+        class BA(B, A):
+            def __init__(self, ba, **kwargs):
+                pass
+        class BA1(BA):
+            pass
+        class CAB(A, B):
+            pass
+        class CBA(B, A):
+            pass
+        class CAB1(A, B1):
+            pass
+        class CB1A(B1, A):
+            pass
+        class D(object):
+            pass
+
+        def test(cls, *expected):
+            self.assertEquals(set(util.get_cls_kwargs(cls)), set(expected))
+
+        test(A, 'a')
+        test(A1, 'a1')
+        test(A11, 'a11', 'a1')
+        test(B, 'b')
+        test(B1, 'b1', 'b')
+        test(AB, 'ab')
+        test(BA, 'ba', 'b', 'a')
+        test(BA1, 'ba', 'b', 'a')
+        test(CAB, 'a')
+        test(CBA, 'b')
+        test(CAB1, 'a')
+        test(CB1A, 'b1', 'b')
+        test(D)
+
+    def test_get_func_kwargs(self):
+        def f1(): pass
+        def f2(foo): pass
+        def f3(*foo): pass
+        def f4(**foo): pass
+
+        def test(fn, *expected):
+            self.assertEquals(set(util.get_func_kwargs(fn)), set(expected))
+
+        test(f1)
+        test(f2, 'foo')
+        test(f3)
+        test(f4)
+
 if __name__ == "__main__":
     testenv.main()