]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- When generating __init__, use a copy of the func_defaults, not a repr of them.
authorJason Kirtland <jek@discorporate.us>
Mon, 11 Aug 2008 18:27:25 +0000 (18:27 +0000)
committerJason Kirtland <jek@discorporate.us>
Mon, 11 Aug 2008 18:27:25 +0000 (18:27 +0000)
lib/sqlalchemy/orm/attributes.py
test/orm/instrumentation.py

index 5076775bce608dcec3bb27d555ee6e5fafc9d8ab..40878be764f4260b8d00b69cd218074292e83bc8 100644 (file)
@@ -1647,7 +1647,7 @@ def _generate_init(class_, class_manager):
     # FIXME: need to juggle local names to avoid constructor argument
     # clashes.
     func_body = """\
-def __init__(%(args)s):
+def __init__(%(apply_pos)s):
     new_state = class_manager._new_state_if_none(%(self_arg)s)
     if new_state:
         return new_state.initialize_instance(%(apply_kw)s)
@@ -1658,8 +1658,13 @@ def __init__(%(args)s):
     func_text = func_body % func_vars
     #TODO: log debug #print func_text
 
+    func = getattr(original__init__, 'im_func', original__init__)
+    func_defaults = getattr(func, 'func_defaults', None)
+
     env = locals().copy()
     exec func_text in env
     __init__ = env['__init__']
     __init__.__doc__ = original__init__.__doc__
+    if func_defaults:
+        __init__.func_defaults = func_defaults
     return __init__
index 358515cb719e0dbf27f10f0149784c8354b78fad..a9d186632b9b36884c59e583543f4c241ea67075 100644 (file)
@@ -403,6 +403,39 @@ class InitTest(_base.ORMTest):
         obj = C()
         eq_(inits, [(C, 'on_init', C)])
 
+    def test_defaulted_init(self):
+        class X(object):
+            def __init__(self_, a, b=123, c='abc'):
+                self_.a = a
+                self_.b = b
+                self_.c = c
+        attributes.register_class(X)
+
+        o = X('foo')
+        eq_(o.a, 'foo')
+        eq_(o.b, 123)
+        eq_(o.c, 'abc')
+
+        class Y(object):
+            unique = object()
+
+            class OutOfScopeForEval(object):
+                def __repr__(self_):
+                    # misleading repr
+                    return '123'
+
+            outofscope = OutOfScopeForEval()
+
+            def __init__(self_, u=unique, o=outofscope):
+                self_.u = u
+                self_.o = o
+
+        attributes.register_class(Y)
+
+        o = Y()
+        assert o.u is Y.unique
+        assert o.o is Y.outofscope
+
 
 class MapperInitTest(_base.ORMTest):
 
@@ -439,6 +472,7 @@ class MapperInitTest(_base.ORMTest):
         self.assertRaises((AttributeError, TypeError),
                           attributes.instance_state, cobj)
 
+
 class InstrumentationCollisionTest(_base.ORMTest):
     def test_none(self):
         class A(object): pass