def __init__(self, arg, **kwargs):
super(ColumnDefault, self).__init__(**kwargs)
if callable(arg):
- if not inspect.isfunction(arg):
- self.arg = lambda ctx: arg()
- else:
- argspec = inspect.getargspec(arg)
- if len(argspec[0]) == 0:
- self.arg = lambda ctx: arg()
- else:
- defaulted = argspec[3] is not None and len(argspec[3]) or 0
- if len(argspec[0]) - defaulted > 1:
- raise exceptions.ArgumentError(
- "ColumnDefault Python function takes zero or one "
- "positional arguments")
- else:
- self.arg = arg
+ arg = self._maybe_wrap_callable(arg)
+ self.arg = arg
+
+ def _maybe_wrap_callable(self, fn):
+ """Backward compat: Wrap callables that don't accept a context."""
+
+ if inspect.isfunction(fn):
+ inspectable = fn
+ elif inspect.isclass(fn):
+ inspectable = fn.__init__
+ elif hasattr(fn, '__call__'):
+ inspectable = fn.__call__
else:
- self.arg = arg
+ # probably not inspectable, try anyways.
+ inspectable = fn
+ try:
+ argspec = inspect.getargspec(inspectable)
+ except TypeError:
+ return lambda ctx: fn()
+
+ positionals = len(argspec[0])
+ if inspect.ismethod(inspectable):
+ positionals -= 1
+
+ if positionals == 0:
+ return lambda ctx: fn()
+
+ defaulted = argspec[3] is not None and len(argspec[3]) or 0
+ if positionals - defaulted > 1:
+ raise exceptions.ArgumentError(
+ "ColumnDefault Python function takes zero or one "
+ "positional arguments")
+ return fn
+
def _visit_name(self):
if self.for_update:
def create(self, bind=None, checkfirst=True):
"""Creates this sequence in the database."""
-
+
if bind is None:
bind = _bind_or_error(self)
bind.create(self, checkfirst=checkfirst)
default_generator['x'] = 50
t.delete().execute()
- def testargsignature(self):
+ def test_bad_argsignature(self):
ex_msg = \
"ColumnDefault Python function takes zero or one positional arguments"
def fn1(x, y): pass
def fn2(x, y, z=3): pass
- for fn in fn1, fn2:
+ class fn3(object):
+ def __init__(self, x, y):
+ pass
+ class FN4(object):
+ def __call__(self, x, y):
+ pass
+ fn4 = FN4()
+
+ for fn in fn1, fn2, fn3, fn4:
try:
c = ColumnDefault(fn)
- assert False
+ assert False, str(fn)
except exceptions.ArgumentError, e:
assert str(e) == ex_msg
- def fn3(): pass
- def fn4(): pass
- def fn5(x=1): pass
- def fn6(x=1, y=2, z=3): pass
- fn7 = list
-
- for fn in fn3, fn4, fn5, fn6, fn7:
+ def test_argsignature(self):
+ def fn1(): pass
+ def fn2(): pass
+ def fn3(x=1): pass
+ def fn4(x=1, y=2, z=3): pass
+ fn5 = list
+ class fn6(object):
+ def __init__(self, x):
+ pass
+ class fn6(object):
+ def __init__(self, x, y=3):
+ pass
+ class FN7(object):
+ def __call__(self, x):
+ pass
+ fn7 = FN7()
+ class FN8(object):
+ def __call__(self, x, y=3):
+ pass
+ fn8 = FN8()
+
+ for fn in fn1, fn2, fn3, fn4, fn5, fn6, fn7, fn8:
c = ColumnDefault(fn)
def teststandalone(self):