From: Mike Bayer Date: Tue, 21 Jul 2009 21:47:03 +0000 (+0000) Subject: - relations() now have greater ability to be "overridden", X-Git-Tag: rel_0_5_6~38 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=b9b62b2369e00be2f344dd96aec94e88c9210fb0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - relations() now have greater ability to be "overridden", meaning a subclass that explicitly specifies a relation() overriding that of the parent class will be honored during a flush. This is currently to support many-to-many relations from concrete inheritance setups. Outside of that use case, YMMV. [ticket:1477] --- diff --git a/CHANGES b/CHANGES index 8fb09fc2d3..e4acdcfeec 100644 --- a/CHANGES +++ b/CHANGES @@ -9,6 +9,13 @@ CHANGES - Fixed bug whereby inheritance discriminator part of a composite primary key would fail on updates. Continuation of [ticket:1300]. + + - relations() now have greater ability to be "overridden", + meaning a subclass that explicitly specifies a relation() + overriding that of the parent class will be honored + during a flush. This is currently to support + many-to-many relations from concrete inheritance setups. + Outside of that use case, YMMV. [ticket:1477] - sql - Fixed a bug in extract() introduced in 0.5.4 whereby diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index da26c8d7b3..ef5b9fc1ab 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -430,6 +430,15 @@ class UOWTask(object): yield rec return collection + def _polymorphic_collection_filtered(fn): + + def collection(self, mappers): + for task in self.polymorphic_tasks: + if task.mapper in mappers: + for rec in fn(task): + yield rec + return collection + @property def elements(self): return self._objects.values() @@ -438,6 +447,10 @@ class UOWTask(object): def polymorphic_elements(self): return self.elements + @_polymorphic_collection_filtered + def filter_polymorphic_elements(self): + return self.elements + @property def polymorphic_tosave_elements(self): return [rec for rec in self.polymorphic_elements if not rec.isdelete] @@ -642,7 +655,19 @@ class UOWDependencyProcessor(object): def __init__(self, processor, targettask): self.processor = processor self.targettask = targettask - + prop = processor.prop + + # define a set of mappers which + # will filter the lists of entities + # this UOWDP processes. this allows + # MapperProperties to be overridden + # at least for concrete mappers. + self._mappers = set([ + m + for m in self.processor.parent.polymorphic_iterator() + if m._props[prop.key] is prop + ]).union(self.processor.mapper.polymorphic_iterator()) + def __repr__(self): return "UOWDependencyProcessor(%s, %s)" % (str(self.processor), str(self.targettask)) @@ -673,12 +698,16 @@ class UOWDependencyProcessor(object): return elem.state ret = False - elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if self not in elem.preprocessed] + elements = [getobj(elem) for elem in + self.targettask.filter_polymorphic_elements(self._mappers) + if self not in elem.preprocessed and not elem.isdelete] if elements: ret = True self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False) - elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if self not in elem.preprocessed] + elements = [getobj(elem) for elem in + self.targettask.filter_polymorphic_elements(self._mappers) + if self not in elem.preprocessed and elem.isdelete] if elements: ret = True self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True) @@ -687,10 +716,10 @@ class UOWDependencyProcessor(object): def execute(self, trans, delete): """process all objects contained within this ``UOWDependencyProcessor``s target task.""" - if delete: - elements = self.targettask.polymorphic_todelete_elements - else: - elements = self.targettask.polymorphic_tosave_elements + + elements = [e for e in + self.targettask.filter_polymorphic_elements(self._mappers) + if e.isdelete==delete] self.processor.process_dependencies( self.targettask, diff --git a/test/orm/inheritance/test_concrete.py b/test/orm/inheritance/test_concrete.py index 4a884cb86c..46bd171e44 100644 --- a/test/orm/inheritance/test_concrete.py +++ b/test/orm/inheritance/test_concrete.py @@ -464,7 +464,61 @@ class PropertyInheritanceTest(_base.MappedTest): sess.query(C).options(eagerload(C.many_a)).order_by(C.id).all(), ) self.assert_sql_count(testing.db, go, 1) + +class ManyToManyTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table("base", metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True) + ) + Table("sub", metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True) + ) + Table("base_mtom", metadata, + Column('base_id', Integer, ForeignKey('base.id'), primary_key=True), + Column('related_id', Integer, ForeignKey('related.id'), primary_key=True) + ) + Table("sub_mtom", metadata, + Column('base_id', Integer, ForeignKey('sub.id'), primary_key=True), + Column('related_id', Integer, ForeignKey('related.id'), primary_key=True) + ) + Table("related", metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True) + ) + + @classmethod + @testing.resolve_artifact_names + def setup_classes(cls): + class Base(_base.ComparableEntity): + pass + class Sub(Base): + pass + class Related(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def test_selective_relations(self): + mapper(Base, base, properties={ + 'related':relation(Related, secondary=base_mtom, backref='bases', order_by=related.c.id) + }) + mapper(Sub, sub, inherits=Base, concrete=True, properties={ + 'related':relation(Related, secondary=sub_mtom, backref='subs', order_by=related.c.id) + }) + mapper(Related, related) + + sess = sessionmaker()() + + b1, s1, r1, r2, r3 = Base(), Sub(), Related(), Related(), Related() + + b1.related.append(r1) + b1.related.append(r2) + s1.related.append(r2) + s1.related.append(r3) + sess.add_all([b1, s1]) + sess.commit() + eq_(s1.related, [r2, r3]) + eq_(b1.related, [r1, r2]) class ColKeysTest(_base.MappedTest): @classmethod