]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
merged -r6134:6172 of trunk
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Jul 2009 20:33:40 +0000 (20:33 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Jul 2009 20:33:40 +0000 (20:33 +0000)
CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/inheritance/test_basic.py
test/orm/inheritance/test_concrete.py

diff --git a/CHANGES b/CHANGES
index c6cb0ff2c7e7fe02a668ac61cbf2a51163a3c117..78f426aaa46c1adcc750625b0f3204dff51a504d 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -5,6 +5,18 @@ CHANGES
 =======
 0.5.6
 =====
+- orm
+    - Fixed bug whereby inheritance discriminator part of a
+      composite primary key would fail on updates.
+      Continuation of [ticket:1300].
+    
+    - relations() now have greater ability to be "overridden",
+      meaning a subclass that explicitly specifies a relation()
+      overriding that of the parent class will be honored
+      during a flush.  This is currently to support 
+      many-to-many relations from concrete inheritance setups.
+      Outside of that use case, YMMV.  [ticket:1477]
+      
 - sql
     - Fixed a bug in extract() introduced in 0.5.4 whereby
       the string "field" argument was getting treated as a 
index 6417397b883828c7df7ef939027d0ca82f034c8e..d155f66d12237b1daa5957d2bbdd1c531d2f9a0a 100644 (file)
@@ -1328,7 +1328,7 @@ class Mapper(object):
                                 history = attributes.get_state_history(state, prop.key, passive=True)
                                 if history.added:
                                     hasdata = True
-                        elif mapper.polymorphic_on and mapper.polymorphic_on.shares_lineage(col):
+                        elif mapper.polymorphic_on and mapper.polymorphic_on.shares_lineage(col) and col not in pks:
                             pass
                         else:
                             if post_update_cols is not None and col not in post_update_cols:
index 682df9323f0208f4e22f40719f556e44c083a3b1..ec28628abc0c3432988a2433d92f22292a9f040a 100644 (file)
@@ -430,6 +430,15 @@ class UOWTask(object):
                     yield rec
         return collection
 
+    def _polymorphic_collection_filtered(fn):
+
+        def collection(self, mappers):
+            for task in self.polymorphic_tasks:
+                if task.mapper in mappers:
+                    for rec in fn(task):
+                        yield rec
+        return collection
+
     @property
     def elements(self):
         return self._objects.values()
@@ -438,6 +447,10 @@ class UOWTask(object):
     def polymorphic_elements(self):
         return self.elements
 
+    @_polymorphic_collection_filtered
+    def filter_polymorphic_elements(self):
+        return self.elements
+
     @property
     def polymorphic_tosave_elements(self):
         return [rec for rec in self.polymorphic_elements if not rec.isdelete]
@@ -642,7 +655,19 @@ class UOWDependencyProcessor(object):
     def __init__(self, processor, targettask):
         self.processor = processor
         self.targettask = targettask
-
+        prop = processor.prop
+        
+        # define a set of mappers which
+        # will filter the lists of entities
+        # this UOWDP processes.  this allows
+        # MapperProperties to be overridden
+        # at least for concrete mappers.
+        self._mappers = set([
+            m
+            for m in self.processor.parent.polymorphic_iterator()
+            if m._props[prop.key] is prop
+        ]).union(self.processor.mapper.polymorphic_iterator())
+            
     def __repr__(self):
         return "UOWDependencyProcessor(%s, %s)" % (str(self.processor), str(self.targettask))
 
@@ -673,12 +698,16 @@ class UOWDependencyProcessor(object):
             return elem.state
 
         ret = False
-        elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if self not in elem.preprocessed]
+        elements = [getobj(elem) for elem in 
+                        self.targettask.filter_polymorphic_elements(self._mappers)
+                        if self not in elem.preprocessed and not elem.isdelete]
         if elements:
             ret = True
             self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False)
 
-        elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if self not in elem.preprocessed]
+        elements = [getobj(elem) for elem in 
+                        self.targettask.filter_polymorphic_elements(self._mappers)
+                        if self not in elem.preprocessed and elem.isdelete]
         if elements:
             ret = True
             self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True)
@@ -687,10 +716,10 @@ class UOWDependencyProcessor(object):
     def execute(self, trans, delete):
         """process all objects contained within this ``UOWDependencyProcessor``s target task."""
 
-        if delete:
-            elements = self.targettask.polymorphic_todelete_elements
-        else:
-            elements = self.targettask.polymorphic_tosave_elements
+
+        elements = [e for e in 
+                    self.targettask.filter_polymorphic_elements(self._mappers) 
+                    if e.isdelete==delete]
 
         self.processor.process_dependencies(
             self.targettask, 
index 88f9bf7520658d27bd76b2d151bcd7243d03ef29..435f26cbaee2b8c24056c350b057bfe5ecd5690c 100644 (file)
@@ -986,6 +986,14 @@ class PKDiscriminatorTest(_base.MappedTest):
         assert a.id
         assert a.type == 2
         
+        p.name='p1new'
+        a.name='a1new'
+        s.flush()
+        
+        s.expire_all()
+        assert a.name=='a1new'
+        assert p.name=='p1new'
+        
         
 class DeleteOrphanTest(_base.MappedTest):
     @classmethod
index 4907e7c60d4fb66ae0aecbe4bf7798de8a7b8333..3a78be9d7bc9c60ee05466cee0ac77a2716df2ae 100644 (file)
@@ -465,7 +465,61 @@ class PropertyInheritanceTest(_base.MappedTest):
                 sess.query(C).options(eagerload(C.many_a)).order_by(C.id).all(),
             )
         self.assert_sql_count(testing.db, go, 1)
+
+class ManyToManyTest(_base.MappedTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("base", metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
+        )
+        Table("sub", metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
+        )
+        Table("base_mtom", metadata,
+            Column('base_id', Integer, ForeignKey('base.id'), primary_key=True),
+            Column('related_id', Integer, ForeignKey('related.id'), primary_key=True)
+        )
+        Table("sub_mtom", metadata,
+            Column('base_id', Integer, ForeignKey('sub.id'), primary_key=True),
+            Column('related_id', Integer, ForeignKey('related.id'), primary_key=True)
+        )
+        Table("related", metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
+        )
+        
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_classes(cls):
+        class Base(_base.ComparableEntity):
+            pass
+        class Sub(Base):
+            pass
+        class Related(_base.ComparableEntity):
+            pass
+
+    @testing.resolve_artifact_names
+    def test_selective_relations(self):
+        mapper(Base, base, properties={
+            'related':relation(Related, secondary=base_mtom, backref='bases', order_by=related.c.id)
+        })
+        mapper(Sub, sub, inherits=Base, concrete=True, properties={
+            'related':relation(Related, secondary=sub_mtom, backref='subs', order_by=related.c.id)
+        })
+        mapper(Related, related)
+        
+        sess = sessionmaker()()
+        
+        b1, s1, r1, r2, r3 = Base(), Sub(), Related(), Related(), Related()
+        
+        b1.related.append(r1)
+        b1.related.append(r2)
+        s1.related.append(r2)
+        s1.related.append(r3)
+        sess.add_all([b1, s1])
+        sess.commit()
         
+        eq_(s1.related, [r2, r3])
+        eq_(b1.related, [r1, r2])
     
 class ColKeysTest(_base.MappedTest):
     @classmethod