]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- wrap ColumnDefault empty arg callables like functools.wraps, setting __name__,...
authorMartin J. Hsu <martin.hsu@gmail.com>
Fri, 25 Sep 2015 08:15:28 +0000 (16:15 +0800)
committerMartin J. Hsu <martin.hsu@gmail.com>
Thu, 15 Oct 2015 02:46:33 +0000 (10:46 +0800)
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/langhelpers.py
test/sql/test_defaults.py

index 137208584e71ba82a1a94e4c369c31fff777429e..0c433d16e4c3e4feaa9a3723fa5f8513c65fe0ad 100644 (file)
@@ -1981,13 +1981,14 @@ class ColumnDefault(DefaultGenerator):
         try:
             argspec = util.get_callable_argspec(fn, no_self=True)
         except TypeError:
-            return lambda ctx: fn()
+            return util.wrap_callable(fn)
 
         defaulted = argspec[3] is not None and len(argspec[3]) or 0
         positionals = len(argspec[0]) - defaulted
 
         if positionals == 0:
-            return lambda ctx: fn()
+            return util.wrap_callable(fn)
+
         elif positionals == 1:
             return fn
         else:
index ed968f1681cec0c127a6a980c855430d30f22797..36a81dbce882ade27b3382d00f6656fbfddffa9f 100644 (file)
@@ -36,7 +36,7 @@ from .langhelpers import iterate_attributes, class_hierarchy, \
     generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \
     safe_reraise,\
     get_callable_argspec, only_once, attrsetter, ellipses_string, \
-    warn_limited, map_bits, MemoizedSlots, EnsureKWArgType
+    warn_limited, map_bits, MemoizedSlots, EnsureKWArgType, wrap_callable
 
 from .deprecations import warn_deprecated, warn_pending_deprecation, \
     deprecated, pending_deprecation, inject_docstring_text
index 743afccfdd6678134330b5e4ebc16a8596b665a7..9f259aea3a0488027a4dac504cca451854e221a7 100644 (file)
@@ -1377,3 +1377,29 @@ class EnsureKWArgType(type):
             return fn(*arg)
         return update_wrapper(wrap, fn)
 
+
+def wrap_callable(fn):
+    """Wrap callable and set __name__, __doc__, and __module__.
+
+    :param fn:
+      object with __call__ method
+    """
+    if hasattr(fn, '__name__'):
+        _f = update_wrapper(lambda ctx: fn(), fn)
+        _f.__doc__ = _f.__doc__ or fn.__name__
+        return _f
+    else:
+        _f = lambda ctx: fn()
+        _f.__name__ = fn.__class__.__name__
+        _f.__module__ = fn.__module__
+
+        if hasattr(fn.__call__, '__doc__') and fn.__call__.__doc__:
+            _f.__doc__ = fn.__call__.__doc__ 
+        elif hasattr(fn.__class__, '__doc__') and fn.__class__.__doc__:
+            _f.__doc__ = fn.__class__.__doc__
+        elif fn.__doc__:
+            _f.__doc__ = fn.__doc__
+        else:
+            _f.__doc__ = fn.__class__.__name__
+
+        return _f
index 673085cf77b2934426c81aad6d8b42bf94ab1fa7..e2250e8346acb1cdceb777f5beb63196f2064d36 100644 (file)
@@ -301,6 +301,67 @@ class DefaultTest(fixtures.TestBase):
             c = sa.ColumnDefault(fn)
             c.arg("context")
 
+    def test_wrapping_update_wrapper_fn(self):
+        def my_fancy_default():
+            """run the fancy default"""
+            return 10
+
+        c = sa.ColumnDefault(my_fancy_default)
+        eq_(c.arg.__name__, "my_fancy_default")
+        eq_(c.arg.__doc__, "run the fancy default")
+
+    def test_wrapping_update_wrapper_fn_nodocstring(self):
+        def my_fancy_default():
+            return 10
+
+        c = sa.ColumnDefault(my_fancy_default)
+        eq_(c.arg.__name__, "my_fancy_default")
+        eq_(c.arg.__doc__, "my_fancy_default")
+
+    def test_wrapping_update_wrapper_cls(self):
+        class MyFancyDefault(object):
+            """a fancy default"""
+
+            def __call__(self):
+                """run the fancy default"""
+                return 10
+
+        c = sa.ColumnDefault(MyFancyDefault())
+        eq_(c.arg.__name__, "MyFancyDefault")
+        eq_(c.arg.__doc__, "run the fancy default")
+
+    def test_wrapping_update_wrapper_cls_noclassdocstring(self):
+        class MyFancyDefault(object):
+
+            def __call__(self):
+                """run the fancy default"""
+                return 10
+
+        c = sa.ColumnDefault(MyFancyDefault())
+        eq_(c.arg.__name__, "MyFancyDefault")
+        eq_(c.arg.__doc__, "run the fancy default")
+
+    def test_wrapping_update_wrapper_cls_nomethoddocstring(self):
+        class MyFancyDefault(object):
+            """a fancy default"""
+
+            def __call__(self):
+                return 10
+
+        c = sa.ColumnDefault(MyFancyDefault())
+        eq_(c.arg.__name__, "MyFancyDefault")
+        eq_(c.arg.__doc__, "a fancy default")
+
+    def test_wrapping_update_wrapper_cls_noclassdocstring_nomethoddocstring(self):
+        class MyFancyDefault(object):
+
+            def __call__(self):
+                return 10
+
+        c = sa.ColumnDefault(MyFancyDefault())
+        eq_(c.arg.__name__, "MyFancyDefault")
+        eq_(c.arg.__doc__, "MyFancyDefault")
+
     @testing.fails_on('firebird', 'Data type unknown')
     def test_standalone(self):
         c = testing.db.engine.contextual_connect()