]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [bug] Fixed bug whereby a subclass of a subclass
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 28 Oct 2011 21:46:28 +0000 (17:46 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 28 Oct 2011 21:46:28 +0000 (17:46 -0400)
using concrete inheritance in conjunction with
the new ConcreteBase or AbstractConcreteBase
would fail to apply the subclasses deeper than
one level to the "polymorphic loader" of each
base  [ticket:2312]

- [bug] Fixed bug whereby a subclass of a subclass
using the new AbstractConcreteBase would fail
to acquire the correct "base_mapper" attribute
when the "base" mapper was generated, thereby
causing failures later on.  [ticket:2312]

CHANGES
lib/sqlalchemy/ext/declarative.py
lib/sqlalchemy/orm/mapper.py
test/ext/test_declarative.py

diff --git a/CHANGES b/CHANGES
index 7a323fe2c90bfbab6af38edbadd04c54d30f4409..a2a3f2eb7870510f32ccfa662e5901617cbd6522 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -38,6 +38,19 @@ CHANGES
      dictionary is up to date, fixes [ticket:2308].
      Thanks to Scott Torborg for the test case here.
 
+   - [bug] Fixed bug whereby a subclass of a subclass
+     using concrete inheritance in conjunction with
+     the new ConcreteBase or AbstractConcreteBase
+     would fail to apply the subclasses deeper than
+     one level to the "polymorphic loader" of each
+     base  [ticket:2312]
+
+   - [bug] Fixed bug whereby a subclass of a subclass
+     using the new AbstractConcreteBase would fail
+     to acquire the correct "base_mapper" attribute
+     when the "base" mapper was generated, thereby
+     causing failures later on.  [ticket:2312]
+
 - sql
    - [feature] Added accessor to types called "python_type",
      returns the rudimentary Python type object
index acfc09396aac938a4b0a2e8222c1bb45df06f878..1f082adf14ac3359f3e9409069956909abf28359 100755 (executable)
@@ -1616,10 +1616,8 @@ class ConcreteBase(object):
         m = cls.__mapper__
         if m.with_polymorphic:
             return
-        mappers = [  sm for sm in [
-                    _mapper_or_none(klass)
-                    for klass in cls.__subclasses__()
-                ] if sm is not None] + [m]
+
+        mappers = list(m.self_and_descendants)
         pjoin = cls._create_polymorphic_union(mappers)
         m._set_with_polymorphic(("*",pjoin))
         m._set_polymorphic_on(pjoin.c.type)
@@ -1661,13 +1659,22 @@ class AbstractConcreteBase(ConcreteBase):
     def __declare_last__(cls):
         if hasattr(cls, '__mapper__'):
             return
-        table = cls._create_polymorphic_union(
-            m for m in [
-                _mapper_or_none(klass)
-                for klass in cls.__subclasses__()
-            ] if m is not None
-        )
-        cls.__mapper__ = m = mapper(cls, table, polymorphic_on=table.c.type)
+
+        # can't rely on 'self_and_descendants' here
+        # since technically an immediate subclass
+        # might not be mapped, but a subclass
+        # may be.
+        mappers = []
+        stack = list(cls.__subclasses__())
+        while stack:
+            klass = stack.pop()
+            stack.extend(klass.__subclasses__())
+            mn = _mapper_or_none(klass)
+            if mn is not None:
+                mappers.append(mn)
+        pjoin = cls._create_polymorphic_union(mappers)
+        cls.__mapper__ = m = mapper(cls, pjoin, polymorphic_on=pjoin.c.type)
+
         for scls in cls.__subclasses__():
             sm = _mapper_or_none(scls)
             if sm.concrete and cls in scls.__bases__:
index e6ec422b068750e390c694dea780c45f539c4296..3a363e73133ce441f3aa734376a7479d345ab9ec 100644 (file)
@@ -577,7 +577,8 @@ class Mapper(object):
             if mapper.polymorphic_on is not None:
                 mapper._requires_row_aliasing = True
         self.batch = self.inherits.batch
-        self.base_mapper = self.inherits.base_mapper
+        for mp in self.self_and_descendants:
+            mp.base_mapper = self.inherits.base_mapper
         self.inherits._inheriting_mappers.add(self)
         self.passive_updates = self.inherits.passive_updates
         self._all_tables = self.inherits._all_tables
index ec59fb86b02b59d0a10563df07a9fb6b7a4ab37f..af7f4e5c44b1765d7911e3caa535536885897ab5 100644 (file)
@@ -2037,25 +2037,27 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
 
 from test.orm.test_events import _RemoveListeners
 class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase):
-    def _roundtrip(self, Employee, Manager, Engineer, polymorphic=True):
+    def _roundtrip(self, Employee, Manager, Engineer, Boss, polymorphic=True):
         Base.metadata.create_all()
         sess = create_session()
         e1 = Engineer(name='dilbert', primary_language='java')
         e2 = Engineer(name='wally', primary_language='c++')
         m1 = Manager(name='dogbert', golf_swing='fore!')
         e3 = Engineer(name='vlad', primary_language='cobol')
-        sess.add_all([e1, e2, m1, e3])
+        b1 = Boss(name="pointy haired")
+        sess.add_all([e1, e2, m1, e3, b1])
         sess.flush()
         sess.expunge_all()
         if polymorphic:
             eq_(sess.query(Employee).order_by(Employee.name).all(),
                 [Engineer(name='dilbert'), Manager(name='dogbert'),
-                Engineer(name='vlad'), Engineer(name='wally')])
+                Boss(name='pointy haired'), Engineer(name='vlad'), Engineer(name='wally')])
         else:
             eq_(sess.query(Engineer).order_by(Engineer.name).all(),
                 [Engineer(name='dilbert'), Engineer(name='vlad'),
                 Engineer(name='wally')])
             eq_(sess.query(Manager).all(), [Manager(name='dogbert')])
+            eq_(sess.query(Boss).all(), [Boss(name='pointy haired')])
 
 
     def test_explicit(self):
@@ -2064,12 +2066,20 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase):
                           test_needs_autoincrement=True), Column('name'
                           , String(50)), Column('primary_language',
                           String(50)))
-        managers = Table('managers', Base.metadata, Column('id',
-                         Integer, primary_key=True,
-                         test_needs_autoincrement=True), Column('name',
-                         String(50)), Column('golf_swing', String(50)))
-        punion = polymorphic_union({'engineer': engineers, 'manager'
-                                   : managers}, 'type', 'punion')
+        managers = Table('managers', Base.metadata, 
+                    Column('id',Integer, primary_key=True, test_needs_autoincrement=True), 
+                    Column('name', String(50)), 
+                    Column('golf_swing', String(50))
+                )
+        boss = Table('boss', Base.metadata, 
+                    Column('id',Integer, primary_key=True, test_needs_autoincrement=True), 
+                    Column('name', String(50)), 
+                    Column('golf_swing', String(50))
+                )
+        punion = polymorphic_union({
+                                'engineer': engineers, 
+                                'manager' : managers,
+                                'boss': boss}, 'type', 'punion')
 
         class Employee(Base, fixtures.ComparableEntity):
 
@@ -2087,7 +2097,13 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase):
             __table__ = managers
             __mapper_args__ = {'polymorphic_identity': 'manager',
                                'concrete': True}
-        self._roundtrip(Employee, Manager, Engineer)
+
+        class Boss(Manager):
+            __table__ = boss
+            __mapper_args__ = {'polymorphic_identity': 'boss',
+                               'concrete': True}
+
+        self._roundtrip(Employee, Manager, Engineer, Boss)
 
     def test_concrete_inline_non_polymorphic(self):
         """test the example from the declarative docs."""
@@ -2116,7 +2132,16 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase):
                         test_needs_autoincrement=True)
             golf_swing = Column(String(50))
             name = Column(String(50))
-        self._roundtrip(Employee, Manager, Engineer, polymorphic=False)
+
+        class Boss(Manager):
+            __tablename__ = 'boss'
+            __mapper_args__ = {'concrete': True}
+            id = Column(Integer, primary_key=True,
+                        test_needs_autoincrement=True)
+            golf_swing = Column(String(50))
+            name = Column(String(50))
+
+        self._roundtrip(Employee, Manager, Engineer, Boss, polymorphic=False)
 
     def test_abstract_concrete_extension(self):
         class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
@@ -2132,6 +2157,16 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase):
                             'polymorphic_identity':'manager', 
                             'concrete':True}
 
+        class Boss(Manager):
+            __tablename__ = 'boss'
+            employee_id = Column(Integer, primary_key=True, 
+                                    test_needs_autoincrement=True)
+            name = Column(String(50))
+            golf_swing = Column(String(40))
+            __mapper_args__ = {
+                            'polymorphic_identity':'boss', 
+                            'concrete':True}
+
         class Engineer(Employee):
             __tablename__ = 'engineer'
             employee_id = Column(Integer, primary_key=True, 
@@ -2141,7 +2176,7 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase):
             __mapper_args__ = {'polymorphic_identity':'engineer', 
                             'concrete':True}
 
-        self._roundtrip(Employee, Manager, Engineer)
+        self._roundtrip(Employee, Manager, Engineer, Boss)
 
     def test_concrete_extension(self):
         class Employee(ConcreteBase, Base, fixtures.ComparableEntity):
@@ -2162,6 +2197,16 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase):
                             'polymorphic_identity':'manager', 
                             'concrete':True}
 
+        class Boss(Manager):
+            __tablename__ = 'boss'
+            employee_id = Column(Integer, primary_key=True, 
+                                    test_needs_autoincrement=True)
+            name = Column(String(50))
+            golf_swing = Column(String(40))
+            __mapper_args__ = {
+                            'polymorphic_identity':'boss', 
+                            'concrete':True}
+
         class Engineer(Employee):
             __tablename__ = 'engineer'
             employee_id = Column(Integer, primary_key=True, 
@@ -2170,7 +2215,7 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase):
             primary_language = Column(String(40))
             __mapper_args__ = {'polymorphic_identity':'engineer', 
                             'concrete':True}
-        self._roundtrip(Employee, Manager, Engineer)
+        self._roundtrip(Employee, Manager, Engineer, Boss)
 
 
 def _produce_test(inline, stringbased):