]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Repaired the usage of merge() when used with
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 1 Aug 2010 18:07:35 +0000 (14:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 1 Aug 2010 18:07:35 +0000 (14:07 -0400)
concrete inheriting mappers.  Such mappers frequently
have so-called "concrete" attributes, which are
subclass attributes that "disable" propagation from
the parent - these needed to allow a merge()
operation to pass through without effect.

CHANGES
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/properties.py
test/orm/inheritance/test_concrete.py

diff --git a/CHANGES b/CHANGES
index a4d60e77fc3ed33efc9d8cfd4349c6fc802a0087..9cc4e1c15d36a84d7be9f6ffd8b8c89b42f8b0d9 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -31,6 +31,13 @@ CHANGES
     version generator is also bypassed if 
     manual changes are present on the attribute.
     [ticket:1857]
+
+  - Repaired the usage of merge() when used with 
+    concrete inheriting mappers.  Such mappers frequently
+    have so-called "concrete" attributes, which are
+    subclass attributes that "disable" propagation from
+    the parent - these needed to allow a merge()
+    operation to pass through without effect.
     
 - sql
   - Changed the scheme used to generate truncated
index 10dd4d8e100affd3f6bf1b067a42e8e157dfc54d..91c2ae403e68b80156f86ae46d5e0f7dcf4c5f1a 100644 (file)
@@ -532,7 +532,8 @@ class MapperProperty(object):
 
         return not self.parent.non_primary
 
-    def merge(self, session, source, dest, load, _recursive):
+    def merge(self, session, source_state, source_dict, dest_state,
+                dest_dict, load, _recursive):
         """Merge the attribute represented by this ``MapperProperty``
         from source to destination object"""
 
index 7008ca7920ecbcf30eef38c73b2f322c84787d3e..cf5f31162f05abe304c602cafb85c00413dbd92e 100644 (file)
@@ -238,6 +238,22 @@ class CompositeProperty(ColumnProperty):
         return str(self.parent.class_.__name__) + "." + self.key
 
 class ConcreteInheritedProperty(MapperProperty):
+    """A 'do nothing' :class:`MapperProperty` that disables 
+    an attribute on a concrete subclass that is only present
+    on the inherited mapper, not the concrete classes' mapper.
+    
+    Cases where this occurs include:
+    
+    * When the superclass mapper is mapped against a 
+      "polymorphic union", which includes all attributes from 
+      all subclasses.
+    * When a relationship() is configured on an inherited mapper,
+      but not on the subclass mapper.  Concrete mappers require
+      that relationship() is configured explicitly on each 
+      subclass. 
+    
+    """
+    
     extension = None
 
     def setup(self, context, entity, path, adapter, **kwargs):
@@ -246,6 +262,10 @@ class ConcreteInheritedProperty(MapperProperty):
     def create_row_processor(self, selectcontext, path, mapper, row, adapter):
         return (None, None)
 
+    def merge(self, session, source_state, source_dict, dest_state,
+                dest_dict, load, _recursive):
+        pass
+        
     def instrument_class(self, mapper):
         def warn():
             raise AttributeError("Concrete %s does not implement "
index a9b0864baa42d7037269168222e433f7bd064d20..2bb25e9e183953352c605e8e2939b844b3ad93e4 100644 (file)
@@ -378,15 +378,21 @@ class PropertyInheritanceTest(_base.MappedTest):
     def define_tables(cls, metadata):
         Table('a_table', metadata, Column('id', Integer,
               primary_key=True, test_needs_autoincrement=True),
-              Column('some_c_id', Integer, ForeignKey('c_table.id')),
+              Column('some_dest_id', Integer, ForeignKey('dest_table.id')),
               Column('aname', String(50)))
         Table('b_table', metadata, Column('id', Integer,
               primary_key=True, test_needs_autoincrement=True),
-              Column('some_c_id', Integer, ForeignKey('c_table.id')),
+              Column('some_dest_id', Integer, ForeignKey('dest_table.id')),
               Column('bname', String(50)))
+
         Table('c_table', metadata, Column('id', Integer,
+            primary_key=True, test_needs_autoincrement=True),
+            Column('some_dest_id', Integer, ForeignKey('dest_table.id')),
+            Column('cname', String(50)))
+            
+        Table('dest_table', metadata, Column('id', Integer,
               primary_key=True, test_needs_autoincrement=True),
-              Column('cname', String(50)))
+              Column('name', String(50)))
 
     @classmethod
     def setup_classes(cls):
@@ -396,57 +402,66 @@ class PropertyInheritanceTest(_base.MappedTest):
 
         class B(A):
             pass
-
-        class C(_base.ComparableEntity):
+        
+        class C(A):
+            pass
+            
+        class Dest(_base.ComparableEntity):
             pass
 
     @testing.resolve_artifact_names
     def test_noninherited_warning(self):
-        mapper(A, a_table, properties={'some_c': relationship(C)})
+        mapper(A, a_table, properties={'some_dest': relationship(Dest)})
         mapper(B, b_table, inherits=A, concrete=True)
-        mapper(C, c_table)
+        mapper(Dest, dest_table)
         b = B()
-        c = C()
-        assert_raises(AttributeError, setattr, b, 'some_c', c)
+        dest = Dest()
+        assert_raises(AttributeError, setattr, b, 'some_dest', dest)
         clear_mappers()
+        
         mapper(A, a_table, properties={'a_id': a_table.c.id})
         mapper(B, b_table, inherits=A, concrete=True)
-        mapper(C, c_table)
+        mapper(Dest, dest_table)
         b = B()
         assert_raises(AttributeError, setattr, b, 'a_id', 3)
         clear_mappers()
+        
         mapper(A, a_table, properties={'a_id': a_table.c.id})
         mapper(B, b_table, inherits=A, concrete=True)
-        mapper(C, c_table)
+        mapper(Dest, dest_table)
 
     @testing.resolve_artifact_names
     def test_inheriting(self):
-        mapper(A, a_table, properties={'some_c': relationship(C,
-               back_populates='many_a')})
+        mapper(A, a_table, properties={
+                'some_dest': relationship(Dest,back_populates='many_a')
+            })
         mapper(B, b_table, inherits=A, concrete=True,
-               properties={'some_c': relationship(C,
-               back_populates='many_b')})
-        mapper(C, c_table, properties={'many_a': relationship(A,
-               back_populates='some_c'), 'many_b': relationship(B,
-               back_populates='some_c')})
+               properties={
+                    'some_dest': relationship(Dest, back_populates='many_b')
+                })
+                    
+        mapper(Dest, dest_table, properties={
+                    'many_a': relationship(A,back_populates='some_dest'), 
+                    'many_b': relationship(B,back_populates='some_dest')
+                })
         sess = sessionmaker()()
-        c1 = C(cname='c1')
-        c2 = C(cname='c2')
-        a1 = A(some_c=c1, aname='a1')
-        a2 = A(some_c=c2, aname='a2')
-        b1 = B(some_c=c1, bname='b1')
-        b2 = B(some_c=c1, bname='b2')
+        dest1 = Dest(name='c1')
+        dest2 = Dest(name='c2')
+        a1 = A(some_dest=dest1, aname='a1')
+        a2 = A(some_dest=dest2, aname='a2')
+        b1 = B(some_dest=dest1, bname='b1')
+        b2 = B(some_dest=dest1, bname='b2')
         assert_raises(AttributeError, setattr, b1, 'aname', 'foo')
         assert_raises(AttributeError, getattr, A, 'bname')
-        assert c2.many_a == [a2]
-        assert c1.many_a == [a1]
-        assert c1.many_b == [b1, b2]
-        sess.add_all([c1, c2])
+        assert dest2.many_a == [a2]
+        assert dest1.many_a == [a1]
+        assert dest1.many_b == [b1, b2]
+        sess.add_all([dest1, dest2])
         sess.commit()
-        assert sess.query(C).filter(C.many_a.contains(a2)).one() is c2
-        assert c2.many_a == [a2]
-        assert c1.many_a == [a1]
-        assert c1.many_b == [b1, b2]
+        assert sess.query(Dest).filter(Dest.many_a.contains(a2)).one() is dest2
+        assert dest2.many_a == [a2]
+        assert dest1.many_a == [a1]
+        assert dest1.many_b == [b1, b2]
         assert sess.query(B).filter(B.bname == 'b1').one() is b1
 
     @testing.resolve_artifact_names
@@ -454,16 +469,17 @@ class PropertyInheritanceTest(_base.MappedTest):
         """test multiple backrefs to the same polymorphically-loading
         attribute."""
 
-        ajoin = polymorphic_union({'a': a_table, 'b': b_table}, 'type',
-                                  'ajoin')
+        ajoin = polymorphic_union({'a': a_table, 'b': b_table, 'c':c_table}, 
+                                'type','ajoin')
         mapper(
             A,
             a_table,
             with_polymorphic=('*', ajoin),
             polymorphic_on=ajoin.c.type,
             polymorphic_identity='a',
-            properties={'some_c': relationship(C,
-                        back_populates='many_a')},
+            properties={
+                'some_dest': relationship(Dest, back_populates='many_a')
+                },
             )
         mapper(
             B,
@@ -471,35 +487,120 @@ class PropertyInheritanceTest(_base.MappedTest):
             inherits=A,
             concrete=True,
             polymorphic_identity='b',
-            properties={'some_c': relationship(C,
-                        back_populates='many_a')},
+            properties={
+                    'some_dest': relationship(Dest, back_populates='many_a')},
             )
-        mapper(C, c_table, properties={'many_a': relationship(A,
-               back_populates='some_c', order_by=ajoin.c.id)})
+
+        mapper(
+            C,
+            c_table,
+            inherits=A,
+            concrete=True,
+            polymorphic_identity='c',
+            properties={
+                    'some_dest': relationship(Dest, back_populates='many_a')},
+            )
+            
+        mapper(Dest, dest_table, properties={
+                'many_a': relationship(A,
+                            back_populates='some_dest', 
+                            order_by=ajoin.c.id)
+                        }
+                )
+                
         sess = sessionmaker()()
-        c1 = C(cname='c1')
-        c2 = C(cname='c2')
-        a1 = A(some_c=c1, aname='a1', id=1)
-        a2 = A(some_c=c2, aname='a2', id=2)
-        b1 = B(some_c=c1, bname='b1', id=3)
-        b2 = B(some_c=c1, bname='b2', id=4)
-        eq_([a2], c2.many_a)
-        eq_([a1, b1, b2], c1.many_a)
-        sess.add_all([c1, c2])
+        dest1 = Dest(name='c1')
+        dest2 = Dest(name='c2')
+        a1 = A(some_dest=dest1, aname='a1', id=1)
+        a2 = A(some_dest=dest2, aname='a2', id=2)
+        b1 = B(some_dest=dest1, bname='b1', id=3)
+        b2 = B(some_dest=dest1, bname='b2', id=4)
+        c1 = C(some_dest=dest1, cname='c1', id=5)
+        c2 = C(some_dest=dest2, cname='c2', id=6)
+        
+        eq_([a2, c2], dest2.many_a)
+        eq_([a1, b1, b2, c1], dest1.many_a)
+        sess.add_all([dest1, dest2])
         sess.commit()
-        assert sess.query(C).filter(C.many_a.contains(a2)).one() is c2
-        assert sess.query(C).filter(C.many_a.contains(b1)).one() is c1
-        eq_(c2.many_a, [a2])
-        eq_(c1.many_a, [a1, b1, b2])
+        
+        assert sess.query(Dest).filter(Dest.many_a.contains(a2)).one() is dest2
+        assert sess.query(Dest).filter(Dest.many_a.contains(b1)).one() is dest1
+        assert sess.query(Dest).filter(Dest.many_a.contains(c2)).one() is dest2
+
+        eq_(dest2.many_a, [a2, c2])
+        eq_(dest1.many_a, [a1, b1, b2, c1])
         sess.expire_all()
 
         def go():
-            eq_([C(many_a=[A(aname='a1'), B(bname='b1'), B(bname='b2'
-                )]), C(many_a=[A(aname='a2')])],
-                sess.query(C).options(joinedload(C.many_a)).order_by(C.id).all())
+            eq_(
+                [
+                    Dest(many_a=[A(aname='a1'), 
+                                    B(bname='b1'), 
+                                    B(bname='b2'),
+                                    C(cname='c1')]), 
+                    Dest(many_a=[A(aname='a2'), C(cname='c2')])],
+                sess.query(Dest).options(joinedload(Dest.many_a)).order_by(Dest.id).all())
 
         self.assert_sql_count(testing.db, go, 1)
 
+    @testing.resolve_artifact_names
+    def test_merge_w_relationship(self):
+        ajoin = polymorphic_union({'a': a_table, 'b': b_table, 'c':c_table}, 
+                                'type','ajoin')
+        mapper(
+            A,
+            a_table,
+            with_polymorphic=('*', ajoin),
+            polymorphic_on=ajoin.c.type,
+            polymorphic_identity='a',
+            properties={
+                'some_dest': relationship(Dest, back_populates='many_a')
+                },
+            )
+        mapper(
+            B,
+            b_table,
+            inherits=A,
+            concrete=True,
+            polymorphic_identity='b',
+            properties={
+                    'some_dest': relationship(Dest, back_populates='many_a')},
+            )
+
+        mapper(
+            C,
+            c_table,
+            inherits=A,
+            concrete=True,
+            polymorphic_identity='c',
+            properties={
+                    'some_dest': relationship(Dest, back_populates='many_a')},
+            )
+            
+        mapper(Dest, dest_table, properties={
+                'many_a': relationship(A,
+                            back_populates='some_dest', 
+                            order_by=ajoin.c.id)
+                        }
+                )
+
+        assert C.some_dest.property.parent is class_mapper(C)
+        assert B.some_dest.property.parent is class_mapper(B)
+        assert A.some_dest.property.parent is class_mapper(A)
+
+        sess = sessionmaker()()
+        dest1 = Dest(name='d1')
+        dest2 = Dest(name='d2')
+        a1 = A(some_dest=dest2, aname='a1')
+        b1 = B(some_dest=dest1, bname='b1')
+        c1 = C(some_dest=dest2, cname='c1')
+        sess.add_all([dest1, dest2, c1, a1, b1])
+        sess.commit()
+        
+        sess = sessionmaker()()
+        merged_c1 = sess.merge(c1)
+        eq_(merged_c1.some_dest.name, 'd2')
+        eq_(merged_c1.some_dest_id, c1.some_dest_id)
 
 class ManyToManyTest(_base.MappedTest):