]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixed a bug where the routine to detect the correct kwargs
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 23 May 2013 16:59:53 +0000 (12:59 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 23 May 2013 16:59:53 +0000 (12:59 -0400)
being sent to :func:`.create_engine` would fail in some cases,
such as with the Sybase dialect.
[ticket:2732]

doc/build/changelog/changelog_08.rst
lib/sqlalchemy/util/langhelpers.py
test/base/test_utils.py

index 9a09f104570550430c9c68021dcb1aa02418db21..fbc79b108e9b3dec4968088ab5986504a4b58372 100644 (file)
@@ -6,6 +6,14 @@
 .. changelog::
     :version: 0.8.2
 
+    .. change::
+      :tags: bug, engine, sybase
+      :tickets: 2732
+
+      Fixed a bug where the routine to detect the correct kwargs
+      being sent to :func:`.create_engine` would fail in some cases,
+      such as with the Sybase dialect.
+
     .. change::
       :tags: bug, orm
       :tickets: 2481
index f6d9164e6c0d6725ddce19a575a14b0388aeb0c9..d82aefdeae7d67f27cec863804540b9b442863d4 100644 (file)
@@ -150,7 +150,7 @@ class PluginLoader(object):
         self.impls[name] = load
 
 
-def get_cls_kwargs(cls):
+def get_cls_kwargs(cls, _set=None):
     """Return the full set of inherited kwargs for the given `cls`.
 
     Probes a class's __init__ method, collecting all named arguments.  If the
@@ -162,33 +162,31 @@ def get_cls_kwargs(cls):
     No anonymous tuple arguments please !
 
     """
+    toplevel = _set == None
+    if toplevel:
+        _set = set()
 
-    for c in cls.__mro__:
-        if '__init__' in c.__dict__:
-            stack = set([c])
-            break
-    else:
-        return []
-
-    args = set()
-    while stack:
-        class_ = stack.pop()
-        ctr = class_.__dict__.get('__init__', False)
-        if (not ctr or
-            not isinstance(ctr, types.FunctionType) or
-                not isinstance(ctr.func_code, types.CodeType)):
-            stack.update(class_.__bases__)
-            continue
+    ctr = cls.__dict__.get('__init__', False)
 
-        # this is shorthand for
-        # names, _, has_kw, _ = inspect.getargspec(ctr)
+    has_init = ctr and isinstance(ctr, types.FunctionType) and \
+        isinstance(ctr.func_code, types.CodeType)
 
+    if has_init:
         names, has_kw = inspect_func_args(ctr)
-        args.update(names)
-        if has_kw:
-            stack.update(class_.__bases__)
-    args.discard('self')
-    return args
+        _set.update(names)
+
+        if not has_kw and not toplevel:
+            return None
+
+    if not has_init or has_kw:
+        for c in cls.__bases__:
+            if get_cls_kwargs(c, _set) is None:
+                break
+
+    _set.discard('self')
+    return _set
+
+
 
 try:
     from inspect import CO_VARKEYWORDS
index af881af179e8dc3f71da074709f5935e1876b493..b28d26e712637c99b86fe2a74467872ce83cd72a 100644 (file)
@@ -1084,6 +1084,10 @@ class ArgInspectionTest(fixtures.TestBase):
             def __init__(self, b1, **kw):
                 pass
 
+        class B2(B):
+            def __init__(self, b2):
+                pass
+
         class AB(A, B):
             def __init__(self, ab):
                 pass
@@ -1101,15 +1105,27 @@ class ArgInspectionTest(fixtures.TestBase):
         class CBA(B, A):
             pass
 
+        class CB1A1(B1, A1):
+            pass
+
         class CAB1(A, B1):
             pass
 
         class CB1A(B1, A):
             pass
 
+        class CB2A(B2, A):
+            pass
+
         class D(object):
             pass
 
+        class BA2(B, A):
+            pass
+
+        class A11B1(A11, B1):
+            pass
+
         def test(cls, *expected):
             eq_(set(util.get_cls_kwargs(cls)), set(expected))
 
@@ -1122,10 +1138,14 @@ class ArgInspectionTest(fixtures.TestBase):
         test(BA, 'ba', 'b', 'a')
         test(BA1, 'ba', 'b', 'a')
         test(CAB, 'a')
-        test(CBA, 'b')
+        test(CBA, 'b', 'a')
         test(CAB1, 'a')
-        test(CB1A, 'b1', 'b')
+        test(CB1A, 'b1', 'b', 'a')
+        test(CB2A, 'b2')
+        test(CB1A1, "a1", "b1", "b")
         test(D)
+        test(BA2, "a", "b")
+        test(A11B1, "a1", "a11", "b", "b1")
 
     def test_get_func_kwargs(self):