From: Jason Kirtland Date: Mon, 4 Feb 2008 20:49:38 +0000 (+0000) Subject: - ColumnDefault callables can now be any kind of compliant callable, previously only... X-Git-Tag: rel_0_4_3~50 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0de289921c4d52798248cfacbacc04ccad12cec9;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - ColumnDefault callables can now be any kind of compliant callable, previously only actual functions were allowed. --- diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 44dcb57555..98e375507e 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -730,22 +730,40 @@ class ColumnDefault(DefaultGenerator): def __init__(self, arg, **kwargs): super(ColumnDefault, self).__init__(**kwargs) if callable(arg): - if not inspect.isfunction(arg): - self.arg = lambda ctx: arg() - else: - argspec = inspect.getargspec(arg) - if len(argspec[0]) == 0: - self.arg = lambda ctx: arg() - else: - defaulted = argspec[3] is not None and len(argspec[3]) or 0 - if len(argspec[0]) - defaulted > 1: - raise exceptions.ArgumentError( - "ColumnDefault Python function takes zero or one " - "positional arguments") - else: - self.arg = arg + arg = self._maybe_wrap_callable(arg) + self.arg = arg + + def _maybe_wrap_callable(self, fn): + """Backward compat: Wrap callables that don't accept a context.""" + + if inspect.isfunction(fn): + inspectable = fn + elif inspect.isclass(fn): + inspectable = fn.__init__ + elif hasattr(fn, '__call__'): + inspectable = fn.__call__ else: - self.arg = arg + # probably not inspectable, try anyways. + inspectable = fn + try: + argspec = inspect.getargspec(inspectable) + except TypeError: + return lambda ctx: fn() + + positionals = len(argspec[0]) + if inspect.ismethod(inspectable): + positionals -= 1 + + if positionals == 0: + return lambda ctx: fn() + + defaulted = argspec[3] is not None and len(argspec[3]) or 0 + if positionals - defaulted > 1: + raise exceptions.ArgumentError( + "ColumnDefault Python function takes zero or one " + "positional arguments") + return fn + def _visit_name(self): if self.for_update: @@ -783,7 +801,7 @@ class Sequence(DefaultGenerator): def create(self, bind=None, checkfirst=True): """Creates this sequence in the database.""" - + if bind is None: bind = _bind_or_error(self) bind.create(self, checkfirst=checkfirst) diff --git a/test/sql/defaults.py b/test/sql/defaults.py index a8311b791a..9b8485231d 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -104,26 +104,49 @@ class DefaultTest(PersistTest): default_generator['x'] = 50 t.delete().execute() - def testargsignature(self): + def test_bad_argsignature(self): ex_msg = \ "ColumnDefault Python function takes zero or one positional arguments" def fn1(x, y): pass def fn2(x, y, z=3): pass - for fn in fn1, fn2: + class fn3(object): + def __init__(self, x, y): + pass + class FN4(object): + def __call__(self, x, y): + pass + fn4 = FN4() + + for fn in fn1, fn2, fn3, fn4: try: c = ColumnDefault(fn) - assert False + assert False, str(fn) except exceptions.ArgumentError, e: assert str(e) == ex_msg - def fn3(): pass - def fn4(): pass - def fn5(x=1): pass - def fn6(x=1, y=2, z=3): pass - fn7 = list - - for fn in fn3, fn4, fn5, fn6, fn7: + def test_argsignature(self): + def fn1(): pass + def fn2(): pass + def fn3(x=1): pass + def fn4(x=1, y=2, z=3): pass + fn5 = list + class fn6(object): + def __init__(self, x): + pass + class fn6(object): + def __init__(self, x, y=3): + pass + class FN7(object): + def __call__(self, x): + pass + fn7 = FN7() + class FN8(object): + def __call__(self, x, y=3): + pass + fn8 = FN8() + + for fn in fn1, fn2, fn3, fn4, fn5, fn6, fn7, fn8: c = ColumnDefault(fn) def teststandalone(self):