]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- get util.get_callable_argspec() to be completely bulletproof for 2.6-3.4,
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Mar 2014 18:59:06 +0000 (13:59 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Mar 2014 18:59:06 +0000 (13:59 -0500)
methods, classes, builtins, functools.partial(), everything known so far
- use get_callable_argspec() within ColumnDefault._maybe_wrap_callable, re: #2979

lib/sqlalchemy/event/attr.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/util/langhelpers.py
test/base/test_utils.py

index 8bb458330d20f0ca1024b2d737e0cb1d833e476f..b44aeefc70f0179ad430227acb571650b624adc4 100644 (file)
@@ -69,7 +69,7 @@ class _DispatchDescriptor(RefCollection):
         if self.legacy_signatures:
             try:
                 argspec = util.get_callable_argspec(fn, no_self=True)
-            except ValueError:
+            except TypeError:
                 pass
             else:
                 fn = legacy._wrap_fn_for_legacy(self, fn, argspec)
index 2614c08c8332476eca534c549ac5bbdb28cb2820..ce31310e765f62fab892b9c03f6654f662c54962 100644 (file)
@@ -1836,42 +1836,18 @@ class ColumnDefault(DefaultGenerator):
     def _maybe_wrap_callable(self, fn):
         """Wrap callables that don't accept a context.
 
-        The alternative here is to require that
-        a simple callable passed to "default" would need
-        to be of the form "default=lambda ctx: datetime.now".
-        That is the more "correct" way to go, but the case
-        of using a zero-arg callable for "default" is so
-        much more prominent than the context-specific one
-        I'm having trouble justifying putting that inconvenience
-        on everyone.
+        This is to allow easy compatiblity with default callables
+        that aren't specific to accepting of a context.
 
         """
-        # TODO: why aren't we using a util.langhelpers function
-        # for this?  e.g. get_callable_argspec
-
-        if isinstance(fn, (types.BuiltinMethodType, types.BuiltinFunctionType)):
-            return lambda ctx: fn()
-        elif inspect.isfunction(fn) or inspect.ismethod(fn):
-            inspectable = fn
-        elif inspect.isclass(fn):
-            inspectable = fn.__init__
-        elif hasattr(fn, '__call__'):
-            inspectable = fn.__call__
-        else:
-            # probably not inspectable, try anyways.
-            inspectable = fn
         try:
-            argspec = inspect.getargspec(inspectable)
+            argspec = util.get_callable_argspec(fn, no_self=True)
         except TypeError:
             return lambda ctx: fn()
 
         defaulted = argspec[3] is not None and len(argspec[3]) or 0
         positionals = len(argspec[0]) - defaulted
 
-        # Py3K compat - no unbound methods
-        if inspect.ismethod(inspectable) or inspect.isclass(fn):
-            positionals -= 1
-
         if positionals == 0:
             return lambda ctx: fn()
         elif positionals == 1:
index 94ddb242ca7c4f788abd5ec4858479444450af45..0af8da381d1c504886ebc47f2a9d7045697b1ac5 100644 (file)
@@ -260,20 +260,42 @@ def get_func_kwargs(func):
 
     return compat.inspect_getargspec(func)[0]
 
-def get_callable_argspec(fn, no_self=False):
-    if isinstance(fn, types.FunctionType):
-        return compat.inspect_getargspec(fn)
-    elif isinstance(fn, types.MethodType) and no_self:
-        spec = compat.inspect_getargspec(fn.__func__)
-        return compat.ArgSpec(spec.args[1:], spec.varargs, spec.keywords, spec.defaults)
+def get_callable_argspec(fn, no_self=False, _is_init=False):
+    """Return the argument signature for any callable.
+
+    All pure-Python callables are accepted, including
+    functions, methods, classes, objects with __call__;
+    builtins and other edge cases like functools.partial() objects
+    raise a TypeError.
+
+    """
+    if inspect.isbuiltin(fn):
+        raise TypeError("Can't inspect builtin: %s" % fn)
+    elif inspect.isfunction(fn):
+        if _is_init and no_self:
+            spec = compat.inspect_getargspec(fn)
+            return compat.ArgSpec(spec.args[1:], spec.varargs,
+                            spec.keywords, spec.defaults)
+        else:
+            return compat.inspect_getargspec(fn)
+    elif inspect.ismethod(fn):
+        if no_self and (_is_init or fn.__self__):
+            spec = compat.inspect_getargspec(fn.__func__)
+            return compat.ArgSpec(spec.args[1:], spec.varargs,
+                            spec.keywords, spec.defaults)
+        else:
+            return compat.inspect_getargspec(fn.__func__)
+    elif inspect.isclass(fn):
+        return get_callable_argspec(fn.__init__, no_self=no_self, _is_init=True)
     elif hasattr(fn, '__func__'):
         return compat.inspect_getargspec(fn.__func__)
-    elif hasattr(fn, '__call__') and \
-        not hasattr(fn.__call__, '__call__'):  # functools.partial does this;
-                                               # not much we can do
-        return get_callable_argspec(fn.__call__)
+    elif hasattr(fn, '__call__'):
+        if inspect.ismethod(fn.__call__):
+            return get_callable_argspec(fn.__call__, no_self=no_self)
+        else:
+            raise TypeError("Can't inspect callable: %s" % fn)
     else:
-        raise ValueError("Can't inspect function: %s" % fn)
+        raise TypeError("Can't inspect callable: %s" % fn)
 
 def format_argspec_plus(fn, grouped=True):
     """Returns a dictionary of formatted, introspected function arguments.
index 4ff17e8cc8f6c3c9bba661b3f2a75f083f71c237..34b707e6cfccb290f151556f01543d3402caf120 100644 (file)
@@ -1377,6 +1377,35 @@ class ArgInspectionTest(fixtures.TestBase):
             (['x', 'y'], None, 'kw', None)
         )
 
+    def test_callable_argspec_fn_no_self(self):
+        def foo(x, y, **kw):
+            pass
+        eq_(
+            get_callable_argspec(foo, no_self=True),
+            (['x', 'y'], None, 'kw', None)
+        )
+
+    def test_callable_argspec_fn_no_self_but_self(self):
+        def foo(self, x, y, **kw):
+            pass
+        eq_(
+            get_callable_argspec(foo, no_self=True),
+            (['self', 'x', 'y'], None, 'kw', None)
+        )
+
+    def test_callable_argspec_py_builtin(self):
+        import datetime
+        assert_raises(
+            TypeError,
+            get_callable_argspec, datetime.datetime.now
+        )
+
+    def test_callable_argspec_obj_init(self):
+        assert_raises(
+            TypeError,
+            get_callable_argspec, object
+        )
+
     def test_callable_argspec_method(self):
         class Foo(object):
             def foo(self, x, y, **kw):
@@ -1386,6 +1415,62 @@ class ArgInspectionTest(fixtures.TestBase):
             (['self', 'x', 'y'], None, 'kw', None)
         )
 
+    def test_callable_argspec_instance_method_no_self(self):
+        class Foo(object):
+            def foo(self, x, y, **kw):
+                pass
+        eq_(
+            get_callable_argspec(Foo().foo, no_self=True),
+            (['x', 'y'], None, 'kw', None)
+        )
+
+    def test_callable_argspec_unbound_method_no_self(self):
+        class Foo(object):
+            def foo(self, x, y, **kw):
+                pass
+        eq_(
+            get_callable_argspec(Foo.foo, no_self=True),
+            (['self', 'x', 'y'], None, 'kw', None)
+        )
+
+    def test_callable_argspec_init(self):
+        class Foo(object):
+            def __init__(self, x, y):
+                pass
+
+        eq_(
+            get_callable_argspec(Foo),
+            (['self', 'x', 'y'], None, None, None)
+        )
+
+    def test_callable_argspec_init_no_self(self):
+        class Foo(object):
+            def __init__(self, x, y):
+                pass
+
+        eq_(
+            get_callable_argspec(Foo, no_self=True),
+            (['x', 'y'], None, None, None)
+        )
+
+    def test_callable_argspec_call(self):
+        class Foo(object):
+            def __call__(self, x, y):
+                pass
+        eq_(
+            get_callable_argspec(Foo()),
+            (['self', 'x', 'y'], None, None, None)
+        )
+
+    def test_callable_argspec_call_no_self(self):
+        class Foo(object):
+            def __call__(self, x, y):
+                pass
+        eq_(
+            get_callable_argspec(Foo(), no_self=True),
+            (['x', 'y'], None, None, None)
+        )
+
     def test_callable_argspec_partial(self):
         from functools import partial
         def foo(x, y, z, **kw):
@@ -1393,7 +1478,7 @@ class ArgInspectionTest(fixtures.TestBase):
         bar = partial(foo, 5)
 
         assert_raises(
-            ValueError,
+            TypeError,
             get_callable_argspec, bar
         )