From: Mike Bayer Date: Wed, 1 Aug 2012 04:08:55 +0000 (-0400) Subject: - [bug] with_polymorphic() produces JOINs X-Git-Tag: rel_0_8_0b1~295 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=b99480aac23fdd4e0ef3633703ca5b32ea49805a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - [bug] with_polymorphic() produces JOINs in the correct order and with correct inheriting tables in the case of sending multi-level subclasses in an arbitrary order or with intermediary classes missing. [ticket:1900] --- diff --git a/CHANGES b/CHANGES index 14216fdc1d..c654fef43f 100644 --- a/CHANGES +++ b/CHANGES @@ -81,6 +81,12 @@ underneath "0.7.xx". contains_eager() [ticket:2438] [ticket:1106] + - [bug] with_polymorphic() produces JOINs + in the correct order and with correct inheriting + tables in the case of sending multi-level + subclasses in an arbitrary order or with + intermediary classes missing. [ticket:1900] + - [feature] The "deferred declarative reflection" system has been moved into the declarative extension itself, using the diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 4ff26cc212..0254933c29 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1347,12 +1347,19 @@ class Mapper(_InspectionAttr): if spec == '*': mappers = list(self.self_and_descendants) elif spec: - mappers = [_class_to_mapper(m) for m in util.to_list(spec)] - for m in mappers: + mappers = set() + for m in util.to_list(spec): + m = _class_to_mapper(m) if not m.isa(self): raise sa_exc.InvalidRequestError( "%r does not inherit from %r" % (m, self)) + + if selectable is None: + mappers.update(m.iterate_to_root()) + else: + mappers.add(m) + mappers = [m for m in self.self_and_descendants if m in mappers] else: mappers = [] diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 6c3a34e3ef..cb4b3ea888 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -10,6 +10,7 @@ from test.lib import testing, engines from test.lib import fixtures from test.orm import _fixtures from test.lib.schema import Table, Column +from sqlalchemy import inspect from sqlalchemy.ext.declarative import declarative_base from test.lib.util import gc_collect @@ -74,6 +75,104 @@ class O2MTest(fixtures.MappedTest): eq_(l[0].parent_foo.data, 'foo #1') eq_(l[1].parent_foo.data, 'foo #1') +class PolymorphicResolutionMultiLevel(fixtures.DeclarativeMappedTest, + testing.AssertsCompiledSQL): + run_setup_mappers = 'once' + __dialect__ = 'default' + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + class A(Base): + __tablename__ = 'a' + id = Column(Integer, primary_key=True) + class B(A): + __tablename__ = 'b' + id = Column(Integer, ForeignKey('a.id'), primary_key=True) + class C(A): + __tablename__ = 'c' + id = Column(Integer, ForeignKey('a.id'), primary_key=True) + class D(B): + __tablename__ = 'd' + id = Column(Integer, ForeignKey('b.id'), primary_key=True) + + def test_ordered_b_d(self): + a_mapper = inspect(self.classes.A) + eq_( + a_mapper._mappers_from_spec( + [self.classes.B, self.classes.D], None), + [a_mapper, inspect(self.classes.B), inspect(self.classes.D)] + ) + + def test_a(self): + a_mapper = inspect(self.classes.A) + eq_( + a_mapper._mappers_from_spec( + [self.classes.A], None), + [a_mapper] + ) + + def test_b_d_selectable(self): + a_mapper = inspect(self.classes.A) + spec = [self.classes.D, self.classes.B] + eq_( + a_mapper._mappers_from_spec( + spec, + self.classes.B.__table__.join(self.classes.D.__table__) + ), + [inspect(self.classes.B), inspect(self.classes.D)] + ) + + def test_d_selectable(self): + a_mapper = inspect(self.classes.A) + spec = [self.classes.D] + eq_( + a_mapper._mappers_from_spec( + spec, + self.classes.B.__table__.join(self.classes.D.__table__) + ), + [inspect(self.classes.D)] + ) + + def test_reverse_d_b(self): + a_mapper = inspect(self.classes.A) + spec = [self.classes.D, self.classes.B] + eq_( + a_mapper._mappers_from_spec( + spec, None), + [a_mapper, inspect(self.classes.B), inspect(self.classes.D)] + ) + mappers, selectable = a_mapper._with_polymorphic_args(spec=spec) + self.assert_compile(selectable, + "a LEFT OUTER JOIN b ON a.id = b.id " + "LEFT OUTER JOIN d ON b.id = d.id") + + def test_d_b_missing(self): + a_mapper = inspect(self.classes.A) + spec = [self.classes.D] + eq_( + a_mapper._mappers_from_spec( + spec, None), + [a_mapper, inspect(self.classes.B), inspect(self.classes.D)] + ) + mappers, selectable = a_mapper._with_polymorphic_args(spec=spec) + self.assert_compile(selectable, + "a LEFT OUTER JOIN b ON a.id = b.id " + "LEFT OUTER JOIN d ON b.id = d.id") + + def test_d_c_b(self): + a_mapper = inspect(self.classes.A) + spec = [self.classes.D, self.classes.C, self.classes.B] + ms = a_mapper._mappers_from_spec(spec, None) + + eq_( + ms[-1], inspect(self.classes.D) + ) + eq_(ms[0], a_mapper) + eq_( + set(ms[1:3]), set(a_mapper._inheriting_mappers) + ) + class PolymorphicOnNotLocalTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata):