From ea44deff14a59da731ffeeafbbcf721921f1404d Mon Sep 17 00:00:00 2001 From: Jason Kirtland Date: Wed, 22 Aug 2007 15:15:52 +0000 Subject: [PATCH] Adjusted ColumnDefault default function fitness check to only insure that a given function had no more than one non-defaulted positional arg. --- lib/sqlalchemy/schema.py | 10 +++++++--- test/sql/defaults.py | 27 ++++++++++++++++++++------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index b6f345be22..1dacad3de9 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -735,10 +735,14 @@ class ColumnDefault(DefaultGenerator): argspec = inspect.getargspec(arg) if len(argspec[0]) == 0: self.arg = lambda ctx: arg() - elif len(argspec[0]) != 1: - raise exceptions.ArgumentError("ColumnDefault Python function takes zero or one positional arguments") else: - self.arg = arg + defaulted = argspec[3] is not None and len(argspec[3]) or 0 + if len(argspec[0]) - defaulted > 1: + raise exceptions.ArgumentError( + "ColumnDefault Python function takes zero or one " + "positional arguments") + else: + self.arg = arg else: self.arg = arg diff --git a/test/sql/defaults.py b/test/sql/defaults.py index 953eb7a354..854c9dc69a 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -99,13 +99,26 @@ class DefaultTest(PersistTest): t.delete().execute() def testargsignature(self): - def mydefault(x, y): - pass - try: - c = ColumnDefault(mydefault) - assert False - except exceptions.ArgumentError, e: - assert str(e) == "ColumnDefault Python function takes zero or one positional arguments", str(e) + ex_msg = \ + "ColumnDefault Python function takes zero or one positional arguments" + + def fn1(x, y): pass + def fn2(x, y, z=3): pass + for fn in fn1, fn2: + try: + c = ColumnDefault(fn) + assert False + except exceptions.ArgumentError, e: + assert str(e) == ex_msg + + def fn3(): pass + def fn4(): pass + def fn5(x=1): pass + def fn6(x=1, y=2, z=3): pass + fn7 = list + + for fn in fn3, fn4, fn5, fn6, fn7: + c = ColumnDefault(fn) def teststandalone(self): c = testbase.db.engine.contextual_connect() -- 2.47.3