From: Jason Kirtland Date: Wed, 22 Aug 2007 15:15:52 +0000 (+0000) Subject: Adjusted ColumnDefault default function fitness check to only insure that a given... X-Git-Tag: rel_0_4beta4~6 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=ea44deff14a59da731ffeeafbbcf721921f1404d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Adjusted ColumnDefault default function fitness check to only insure that a given function had no more than one non-defaulted positional arg. --- diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index b6f345be22..1dacad3de9 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -735,10 +735,14 @@ class ColumnDefault(DefaultGenerator): argspec = inspect.getargspec(arg) if len(argspec[0]) == 0: self.arg = lambda ctx: arg() - elif len(argspec[0]) != 1: - raise exceptions.ArgumentError("ColumnDefault Python function takes zero or one positional arguments") else: - self.arg = arg + defaulted = argspec[3] is not None and len(argspec[3]) or 0 + if len(argspec[0]) - defaulted > 1: + raise exceptions.ArgumentError( + "ColumnDefault Python function takes zero or one " + "positional arguments") + else: + self.arg = arg else: self.arg = arg diff --git a/test/sql/defaults.py b/test/sql/defaults.py index 953eb7a354..854c9dc69a 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -99,13 +99,26 @@ class DefaultTest(PersistTest): t.delete().execute() def testargsignature(self): - def mydefault(x, y): - pass - try: - c = ColumnDefault(mydefault) - assert False - except exceptions.ArgumentError, e: - assert str(e) == "ColumnDefault Python function takes zero or one positional arguments", str(e) + ex_msg = \ + "ColumnDefault Python function takes zero or one positional arguments" + + def fn1(x, y): pass + def fn2(x, y, z=3): pass + for fn in fn1, fn2: + try: + c = ColumnDefault(fn) + assert False + except exceptions.ArgumentError, e: + assert str(e) == ex_msg + + def fn3(): pass + def fn4(): pass + def fn5(x=1): pass + def fn6(x=1, y=2, z=3): pass + fn7 = list + + for fn in fn3, fn4, fn5, fn6, fn7: + c = ColumnDefault(fn) def teststandalone(self): c = testbase.db.engine.contextual_connect()