]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- relations() now have greater ability to be "overridden",
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Jul 2009 21:47:03 +0000 (21:47 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Jul 2009 21:47:03 +0000 (21:47 +0000)
meaning a subclass that explicitly specifies a relation()
overriding that of the parent class will be honored
during a flush.  This is currently to support
many-to-many relations from concrete inheritance setups.
Outside of that use case, YMMV.  [ticket:1477]

CHANGES
lib/sqlalchemy/orm/unitofwork.py
test/orm/inheritance/test_concrete.py

diff --git a/CHANGES b/CHANGES
index 8fb09fc2d378f95efbd96ffba611e364285a76ea..e4acdcfeecf5d18c29ac8ce5041cb0f67e356f21 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -9,6 +9,13 @@ CHANGES
     - Fixed bug whereby inheritance discriminator part of a
       composite primary key would fail on updates.
       Continuation of [ticket:1300].
+    
+    - relations() now have greater ability to be "overridden",
+      meaning a subclass that explicitly specifies a relation()
+      overriding that of the parent class will be honored
+      during a flush.  This is currently to support 
+      many-to-many relations from concrete inheritance setups.
+      Outside of that use case, YMMV.  [ticket:1477]
       
 - sql
     - Fixed a bug in extract() introduced in 0.5.4 whereby
index da26c8d7b38f464a5017eb09e790167b9a7cd6e2..ef5b9fc1abcaebcb666307dae6538a83345050b3 100644 (file)
@@ -430,6 +430,15 @@ class UOWTask(object):
                     yield rec
         return collection
 
+    def _polymorphic_collection_filtered(fn):
+
+        def collection(self, mappers):
+            for task in self.polymorphic_tasks:
+                if task.mapper in mappers:
+                    for rec in fn(task):
+                        yield rec
+        return collection
+
     @property
     def elements(self):
         return self._objects.values()
@@ -438,6 +447,10 @@ class UOWTask(object):
     def polymorphic_elements(self):
         return self.elements
 
+    @_polymorphic_collection_filtered
+    def filter_polymorphic_elements(self):
+        return self.elements
+
     @property
     def polymorphic_tosave_elements(self):
         return [rec for rec in self.polymorphic_elements if not rec.isdelete]
@@ -642,7 +655,19 @@ class UOWDependencyProcessor(object):
     def __init__(self, processor, targettask):
         self.processor = processor
         self.targettask = targettask
-
+        prop = processor.prop
+        
+        # define a set of mappers which
+        # will filter the lists of entities
+        # this UOWDP processes.  this allows
+        # MapperProperties to be overridden
+        # at least for concrete mappers.
+        self._mappers = set([
+            m
+            for m in self.processor.parent.polymorphic_iterator()
+            if m._props[prop.key] is prop
+        ]).union(self.processor.mapper.polymorphic_iterator())
+            
     def __repr__(self):
         return "UOWDependencyProcessor(%s, %s)" % (str(self.processor), str(self.targettask))
 
@@ -673,12 +698,16 @@ class UOWDependencyProcessor(object):
             return elem.state
 
         ret = False
-        elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if self not in elem.preprocessed]
+        elements = [getobj(elem) for elem in 
+                        self.targettask.filter_polymorphic_elements(self._mappers)
+                        if self not in elem.preprocessed and not elem.isdelete]
         if elements:
             ret = True
             self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False)
 
-        elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if self not in elem.preprocessed]
+        elements = [getobj(elem) for elem in 
+                        self.targettask.filter_polymorphic_elements(self._mappers)
+                        if self not in elem.preprocessed and elem.isdelete]
         if elements:
             ret = True
             self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True)
@@ -687,10 +716,10 @@ class UOWDependencyProcessor(object):
     def execute(self, trans, delete):
         """process all objects contained within this ``UOWDependencyProcessor``s target task."""
 
-        if delete:
-            elements = self.targettask.polymorphic_todelete_elements
-        else:
-            elements = self.targettask.polymorphic_tosave_elements
+
+        elements = [e for e in 
+                    self.targettask.filter_polymorphic_elements(self._mappers) 
+                    if e.isdelete==delete]
 
         self.processor.process_dependencies(
             self.targettask, 
index 4a884cb86c71a72d8588ce406c70103e048c8e76..46bd171e4405286b7ffca8d226c2762559890dbb 100644 (file)
@@ -464,7 +464,61 @@ class PropertyInheritanceTest(_base.MappedTest):
                 sess.query(C).options(eagerload(C.many_a)).order_by(C.id).all(),
             )
         self.assert_sql_count(testing.db, go, 1)
+
+class ManyToManyTest(_base.MappedTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("base", metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
+        )
+        Table("sub", metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
+        )
+        Table("base_mtom", metadata,
+            Column('base_id', Integer, ForeignKey('base.id'), primary_key=True),
+            Column('related_id', Integer, ForeignKey('related.id'), primary_key=True)
+        )
+        Table("sub_mtom", metadata,
+            Column('base_id', Integer, ForeignKey('sub.id'), primary_key=True),
+            Column('related_id', Integer, ForeignKey('related.id'), primary_key=True)
+        )
+        Table("related", metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
+        )
+        
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_classes(cls):
+        class Base(_base.ComparableEntity):
+            pass
+        class Sub(Base):
+            pass
+        class Related(_base.ComparableEntity):
+            pass
+
+    @testing.resolve_artifact_names
+    def test_selective_relations(self):
+        mapper(Base, base, properties={
+            'related':relation(Related, secondary=base_mtom, backref='bases', order_by=related.c.id)
+        })
+        mapper(Sub, sub, inherits=Base, concrete=True, properties={
+            'related':relation(Related, secondary=sub_mtom, backref='subs', order_by=related.c.id)
+        })
+        mapper(Related, related)
+        
+        sess = sessionmaker()()
+        
+        b1, s1, r1, r2, r3 = Base(), Sub(), Related(), Related(), Related()
+        
+        b1.related.append(r1)
+        b1.related.append(r2)
+        s1.related.append(r2)
+        s1.related.append(r3)
+        sess.add_all([b1, s1])
+        sess.commit()
         
+        eq_(s1.related, [r2, r3])
+        eq_(b1.related, [r1, r2])
     
 class ColKeysTest(_base.MappedTest):
     @classmethod