]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Look for __sa_reconstructor__ on original_init
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Feb 2018 14:36:25 +0000 (09:36 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Feb 2018 14:36:25 +0000 (09:36 -0500)
Fixed bug where the :func:`.orm.reconstructor` event
helper would not be recognized if it were applied to the
``__init__()`` method of the mapped class.

It's not clear when this bug appeared, but was likely
during a refactoring of instrumentation mechanics somewhere
between 0.8 and 1.0.

Change-Id: Iaeb3baffef9e1b40a336d44294e68479f5d65fd3
Fixes: #4178
doc/build/changelog/unreleased_12/4178.rst [new file with mode: 0644]
lib/sqlalchemy/orm/instrumentation.py
lib/sqlalchemy/orm/mapper.py
test/orm/test_mapper.py

diff --git a/doc/build/changelog/unreleased_12/4178.rst b/doc/build/changelog/unreleased_12/4178.rst
new file mode 100644 (file)
index 0000000..5f4844a
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 4178
+
+    Fixed bug where the :func:`.orm.reconstructor` event
+    helper would not be recognized if it were applied to the
+    ``__init__()`` method of the mapped class.
index c17f8ab8d772c53e17140eb82c0c6f4923cdbf91..5a96c57cf7c56cf565eb9d2da17586fa8ff7bb47 100644 (file)
@@ -519,6 +519,7 @@ def __init__(%(apply_pos)s):
     exec(func_text, env)
     __init__ = env['__init__']
     __init__.__doc__ = original__init__.__doc__
+    __init__._sa_original_init = original__init__
 
     if func_defaults:
         __init__.__defaults__ = func_defaults
index 9ce8487993de6242c0c3c64f8f10c97751c801b9..cd9e00b8b0155fe53f4e96a8ce3b2217f01c4c6c 100644 (file)
@@ -1243,6 +1243,10 @@ class Mapper(InspectionAttr):
         event.listen(manager, 'init', _event_on_init, raw=True)
 
         for key, method in util.iterate_attributes(self.class_):
+            if key == '__init__' and hasattr(method, '_sa_original_init'):
+                method = method._sa_original_init
+                if isinstance(method, types.MethodType):
+                    method = method.im_func
             if isinstance(method, types.FunctionType):
                 if hasattr(method, '__sa_reconstructor__'):
                     self._reconstructor = method
index 42d114f6faf2b8b58bc111d8b377d0abbb754b9e..0ff9c12ad37fdb2f60da321c83bbf2de17e6dfc0 100644 (file)
@@ -1694,6 +1694,71 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         sess.query(C).first()
         eq_(recon, ['A', 'B', 'C'])
 
+    def test_reconstructor_init(self):
+
+        users = self.tables.users
+
+        recon = []
+
+        class User(object):
+
+            @reconstructor
+            def __init__(self):
+                recon.append('go')
+
+        mapper(User, users)
+
+        User()
+        eq_(recon, ['go'])
+
+        recon[:] = []
+        create_session().query(User).first()
+        eq_(recon, ['go'])
+
+    def test_reconstructor_init_inheritance(self):
+        users = self.tables.users
+
+        recon = []
+
+        class A(object):
+
+            @reconstructor
+            def __init__(self):
+                assert isinstance(self, A)
+                recon.append('A')
+
+        class B(A):
+
+            @reconstructor
+            def __init__(self):
+                assert isinstance(self, B)
+                recon.append('B')
+
+        class C(A):
+
+            @reconstructor
+            def __init__(self):
+                assert isinstance(self, C)
+                recon.append('C')
+
+        mapper(A, users, polymorphic_on=users.c.name,
+               polymorphic_identity='jack')
+        mapper(B, inherits=A, polymorphic_identity='ed')
+        mapper(C, inherits=A, polymorphic_identity='chuck')
+
+        A()
+        B()
+        C()
+        eq_(recon, ['A', 'B', 'C'])
+
+        recon[:] = []
+        sess = create_session()
+        sess.query(A).first()
+        sess.query(B).first()
+        sess.query(C).first()
+        eq_(recon, ['A', 'B', 'C'])
+
+
     def test_unmapped_reconstructor_inheritance(self):
         users = self.tables.users
 
@@ -1854,7 +1919,6 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
 
         mapper(B, users)
 
-
 class DocumentTest(fixtures.TestBase):
 
     def test_doc_propagate(self):