]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Adjusted ColumnDefault default function fitness check to only insure that a given...
authorJason Kirtland <jek@discorporate.us>
Wed, 22 Aug 2007 15:15:52 +0000 (15:15 +0000)
committerJason Kirtland <jek@discorporate.us>
Wed, 22 Aug 2007 15:15:52 +0000 (15:15 +0000)
lib/sqlalchemy/schema.py
test/sql/defaults.py

index b6f345be2254e2720812b2bcc0433598a13ba9f0..1dacad3de98ff635602f50261e9499ebdca0ec34 100644 (file)
@@ -735,10 +735,14 @@ class ColumnDefault(DefaultGenerator):
                 argspec = inspect.getargspec(arg)
                 if len(argspec[0]) == 0:
                     self.arg = lambda ctx: arg()
-                elif len(argspec[0]) != 1:
-                    raise exceptions.ArgumentError("ColumnDefault Python function takes zero or one positional arguments")
                 else:
-                    self.arg = arg
+                    defaulted = argspec[3] is not None and len(argspec[3]) or 0
+                    if len(argspec[0]) - defaulted > 1:
+                        raise exceptions.ArgumentError(
+                            "ColumnDefault Python function takes zero or one "
+                            "positional arguments")
+                    else:
+                        self.arg = arg
         else:
             self.arg = arg
 
index 953eb7a3547fb9204190fe7d32a1951ee217aaf8..854c9dc69a052a70064a56e21b9efd459feb566b 100644 (file)
@@ -99,13 +99,26 @@ class DefaultTest(PersistTest):
         t.delete().execute()
     
     def testargsignature(self):
-        def mydefault(x, y):
-            pass
-        try:
-            c = ColumnDefault(mydefault)
-            assert False
-        except exceptions.ArgumentError, e:
-            assert str(e) == "ColumnDefault Python function takes zero or one positional arguments", str(e)
+        ex_msg = \
+          "ColumnDefault Python function takes zero or one positional arguments"
+
+        def fn1(x, y): pass
+        def fn2(x, y, z=3): pass
+        for fn in fn1, fn2:
+            try:
+                c = ColumnDefault(fn)
+                assert False
+            except exceptions.ArgumentError, e:
+                assert str(e) == ex_msg
+
+        def fn3(): pass
+        def fn4(): pass
+        def fn5(x=1): pass
+        def fn6(x=1, y=2, z=3): pass
+        fn7 = list
+
+        for fn in fn3, fn4, fn5, fn6, fn7:
+            c = ColumnDefault(fn)
             
     def teststandalone(self):
         c = testbase.db.engine.contextual_connect()