]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- ColumnDefault callables can now be any kind of compliant callable, previously only...
authorJason Kirtland <jek@discorporate.us>
Mon, 4 Feb 2008 20:49:38 +0000 (20:49 +0000)
committerJason Kirtland <jek@discorporate.us>
Mon, 4 Feb 2008 20:49:38 +0000 (20:49 +0000)
lib/sqlalchemy/schema.py
test/sql/defaults.py

index 44dcb575558fe71d6ad6f8f3c94f3239ba760ac8..98e375507e5133c60c972c65df86709f7348d55c 100644 (file)
@@ -730,22 +730,40 @@ class ColumnDefault(DefaultGenerator):
     def __init__(self, arg, **kwargs):
         super(ColumnDefault, self).__init__(**kwargs)
         if callable(arg):
-            if not inspect.isfunction(arg):
-                self.arg = lambda ctx: arg()
-            else:
-                argspec = inspect.getargspec(arg)
-                if len(argspec[0]) == 0:
-                    self.arg = lambda ctx: arg()
-                else:
-                    defaulted = argspec[3] is not None and len(argspec[3]) or 0
-                    if len(argspec[0]) - defaulted > 1:
-                        raise exceptions.ArgumentError(
-                            "ColumnDefault Python function takes zero or one "
-                            "positional arguments")
-                    else:
-                        self.arg = arg
+            arg = self._maybe_wrap_callable(arg)
+        self.arg = arg
+
+    def _maybe_wrap_callable(self, fn):
+        """Backward compat: Wrap callables that don't accept a context."""
+
+        if inspect.isfunction(fn):
+            inspectable = fn
+        elif inspect.isclass(fn):
+            inspectable = fn.__init__
+        elif hasattr(fn, '__call__'):
+            inspectable = fn.__call__
         else:
-            self.arg = arg
+            # probably not inspectable, try anyways.
+            inspectable = fn
+        try:
+            argspec = inspect.getargspec(inspectable)
+        except TypeError:
+            return lambda ctx: fn()
+
+        positionals = len(argspec[0])
+        if inspect.ismethod(inspectable):
+            positionals -= 1
+
+        if positionals == 0:
+            return lambda ctx: fn()
+
+        defaulted = argspec[3] is not None and len(argspec[3]) or 0
+        if positionals - defaulted > 1:
+            raise exceptions.ArgumentError(
+                "ColumnDefault Python function takes zero or one "
+                "positional arguments")
+        return fn
+
 
     def _visit_name(self):
         if self.for_update:
@@ -783,7 +801,7 @@ class Sequence(DefaultGenerator):
 
     def create(self, bind=None, checkfirst=True):
         """Creates this sequence in the database."""
-        
+
         if bind is None:
             bind = _bind_or_error(self)
         bind.create(self, checkfirst=checkfirst)
index a8311b791ab9d3bb9f513fd49d26cebd8dcad780..9b8485231d1cd3e2576e2c00d7b89c88084572b6 100644 (file)
@@ -104,26 +104,49 @@ class DefaultTest(PersistTest):
         default_generator['x'] = 50
         t.delete().execute()
 
-    def testargsignature(self):
+    def test_bad_argsignature(self):
         ex_msg = \
           "ColumnDefault Python function takes zero or one positional arguments"
 
         def fn1(x, y): pass
         def fn2(x, y, z=3): pass
-        for fn in fn1, fn2:
+        class fn3(object):
+            def __init__(self, x, y):
+                pass
+        class FN4(object):
+            def __call__(self, x, y):
+                pass
+        fn4 = FN4()
+
+        for fn in fn1, fn2, fn3, fn4:
             try:
                 c = ColumnDefault(fn)
-                assert False
+                assert False, str(fn)
             except exceptions.ArgumentError, e:
                 assert str(e) == ex_msg
 
-        def fn3(): pass
-        def fn4(): pass
-        def fn5(x=1): pass
-        def fn6(x=1, y=2, z=3): pass
-        fn7 = list
-
-        for fn in fn3, fn4, fn5, fn6, fn7:
+    def test_argsignature(self):
+        def fn1(): pass
+        def fn2(): pass
+        def fn3(x=1): pass
+        def fn4(x=1, y=2, z=3): pass
+        fn5 = list
+        class fn6(object):
+            def __init__(self, x):
+                pass
+        class fn6(object):
+            def __init__(self, x, y=3):
+                pass
+        class FN7(object):
+            def __call__(self, x):
+                pass
+        fn7 = FN7()
+        class FN8(object):
+            def __call__(self, x, y=3):
+                pass
+        fn8 = FN8()
+
+        for fn in fn1, fn2, fn3, fn4, fn5, fn6, fn7, fn8:
             c = ColumnDefault(fn)
 
     def teststandalone(self):