]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the column assigned to polymorphic_on now behaves like any other
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 20 Nov 2010 21:25:12 +0000 (16:25 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 20 Nov 2010 21:25:12 +0000 (16:25 -0500)
mapped attribute, in that it can be assigned to, mapped to multiple
columns.  It is also populated immediately upon object construction
with its class-based value, so the attribute can be read before
any flush occurs.  [ticket:1895]

lib/sqlalchemy/orm/mapper.py
test/orm/inheritance/test_basic.py

index 8fc0b4132121657e99749acf32f329e18e53aed6..32eb8f64333eac9b26072519970bd5c47dcdee9d 100644 (file)
@@ -208,6 +208,7 @@ class Mapper(object):
             self._configure_class_instrumentation()
             self._configure_listeners()
             self._configure_properties()
+            self._configure_polymorphic_setter()
             self._configure_pks()
             global _new_mappers
             _new_mappers = True
@@ -310,7 +311,7 @@ class Mapper(object):
             if self.polymorphic_identity is not None:
                 self.polymorphic_map[self.polymorphic_identity] = self
             self._identity_class = self.class_
-
+        
         if self.mapped_table is None:
             raise sa_exc.ArgumentError(
                     "Mapper '%s' does not have a mapped_table specified." 
@@ -560,33 +561,65 @@ class Mapper(object):
                                     init=False, 
                                     setparent=True)
 
+    def _configure_polymorphic_setter(self):
+        """Configure an attribute on the mapper representing the 
+        'polymorphic_on' column, if applicable, and not 
+        already generated by _configure_properties (which is typical).
+        
+        Also create a setter function which will assign this
+        attribute to the value of the 'polymorphic_identity'
+        upon instance construction, also if applicable.  This 
+        routine will run when an instance is created.
+        
+        """
         # do a special check for the "discriminiator" column, as it 
         # may only be present in the 'with_polymorphic' selectable 
         # but we need it for the base mapper
-        if self.polymorphic_on is not None and \
-                    self.polymorphic_on not in self._columntoproperty:
-            col = self.mapped_table.corresponding_column(self.polymorphic_on)
-            if col is None:
-                instrument = False
-                col = self.polymorphic_on
-                if self.with_polymorphic is None \
-                    or self.with_polymorphic[1].corresponding_column(col) \
-                    is None:
-                    util.warn("Could not map polymorphic_on column "
-                              "'%s' to the mapped table - polymorphic "
-                              "loads will not function properly"
-                              % col.description)
-            else:
-                instrument = True
-            if self._should_exclude(col.key, col.key, local=False, column=col):
-                raise sa_exc.InvalidRequestError(
-                    "Cannot exclude or override the discriminator column %r" %
-                    col.key)
+        setter = False
+        
+        if self.polymorphic_on is not None:
+            setter = True
+
+            if self.polymorphic_on not in self._columntoproperty:
+                col = self.mapped_table.corresponding_column(self.polymorphic_on)
+                if col is None:
+                    setter = False
+                    instrument = False
+                    col = self.polymorphic_on
+                    if self.with_polymorphic is None \
+                        or self.with_polymorphic[1].corresponding_column(col) \
+                        is None:
+                        util.warn("Could not map polymorphic_on column "
+                                  "'%s' to the mapped table - polymorphic "
+                                  "loads will not function properly"
+                                  % col.description)
+                else:
+                    instrument = True
+
+                if self._should_exclude(col.key, col.key, False, col):
+                    raise sa_exc.InvalidRequestError(
+                        "Cannot exclude or override the discriminator column %r" %
+                        col.key)
                     
-            self._configure_property(
-                            col.key, 
-                            properties.ColumnProperty(col, _instrument=instrument),
-                            init=False, setparent=True)
+                self._configure_property(
+                                col.key, 
+                                properties.ColumnProperty(col, _instrument=instrument),
+                                init=False, setparent=True)
+                polymorphic_key = col.key
+            else:
+                polymorphic_key = self._columntoproperty[self.polymorphic_on].key
+
+        if setter:
+            def _set_polymorphic_identity(state):
+                dict_ = state.dict
+                state.get_impl(polymorphic_key).set(state, dict_,
+                        self.polymorphic_identity, None)
+
+            self._set_polymorphic_identity = _set_polymorphic_identity
+        else:
+            self._set_polymorphic_identity = None
+        
+
 
     def _adapt_inherited_property(self, key, prop, init):
         if not self.concrete:
@@ -1656,13 +1689,6 @@ class Mapper(object):
                         if col is mapper.version_id_col:
                             params[col.key] = \
                               mapper.version_id_generator(None)
-                        elif mapper.polymorphic_on is not None and \
-                                mapper.polymorphic_on.shares_lineage(col):
-                            value = mapper.polymorphic_identity
-                            if ((col.default is None and
-                                 col.server_default is None) or
-                                value is not None):
-                                params[col.key] = value
                         elif col in pks:
                             value = \
                              mapper._get_state_attr_by_column(
@@ -1715,10 +1741,6 @@ class Mapper(object):
                                                     passive=True)
                                     if history.added:
                                         hasdata = True
-                        elif mapper.polymorphic_on is not None and \
-                            mapper.polymorphic_on.shares_lineage(col) and \
-                                col not in pks:
-                            pass
                         else:
                             prop = mapper._columntoproperty[col]
                             history = attributes.get_state_history(
@@ -1890,11 +1912,6 @@ class Mapper(object):
         postfetch_cols = resultproxy.postfetch_cols()
         generated_cols = list(resultproxy.prefetch_cols())
 
-        if self.polymorphic_on is not None:
-            po = table.corresponding_column(self.polymorphic_on)
-            if po is not None:
-                generated_cols.append(po)
-
         if self.version_id_col is not None:
             generated_cols.append(self.version_id_col)
 
@@ -2390,6 +2407,9 @@ def _event_on_init(state, args, kwargs):
     if instrumenting_mapper:
         # compile() always compiles all mappers
         instrumenting_mapper.compile()
+        
+        if instrumenting_mapper._set_polymorphic_identity:
+            instrumenting_mapper._set_polymorphic_identity(state)
 
 def _event_on_resurrect(state):
     # re-populate the primary key elements
index 3e4d50d96b36701d2af28b42c79a5a2e8ff0aafa..896c7618f09e5f45b6ed63f24675a5a22ccdbf2f 100644 (file)
@@ -201,6 +201,75 @@ class PolymorphicSynonymTest(_base.MappedTest):
         eq_(sess.query(T2).filter(T2.info=='at2').one(), at2)
         eq_(at2.info, "THE INFO IS:at2")
         
+class PolymorphicAttributeManagementTest(_base.MappedTest):
+    """Test polymorphic_on can be assigned, can be mirrored, etc."""
+
+    run_setup_mappers = 'once'
+    
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('table_a', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('class_name', String(50))
+        )
+        Table('table_b', metadata,
+           Column('id', Integer, ForeignKey('table_a.id'), primary_key=True),
+           Column('class_name', String(50))
+        )
+        Table('table_c', metadata,
+           Column('id', Integer, ForeignKey('table_b.id'),primary_key=True)
+        )
+    
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_classes(cls):
+        class A(_base.ComparableEntity):
+            pass
+        class B(A):
+            pass
+        class C(B):
+            pass
+        
+        mapper(A, table_a, 
+                        polymorphic_on=table_a.c.class_name, 
+                        polymorphic_identity='a')
+        mapper(B, table_b, inherits=A, 
+                        polymorphic_on=table_b.c.class_name, 
+                        polymorphic_identity='b')
+        mapper(C, table_c, inherits=B, 
+                        polymorphic_identity='c')
+    
+    @testing.resolve_artifact_names
+    def test_poly_configured_immediate(self):
+        a = A()
+        b = B()
+        c = C()
+        eq_(a.class_name, 'a')
+        eq_(b.class_name, 'b')
+        eq_(c.class_name, 'c')
+        
+    @testing.resolve_artifact_names
+    def test_base_class(self):
+        sess = Session()
+        c1 = C()
+        sess.add(c1)
+        sess.commit()
+
+        assert isinstance(sess.query(B).first(), C)
+
+        sess.close()
+
+        assert isinstance(sess.query(A).first(), C)
+
+    @testing.resolve_artifact_names
+    def test_assignment(self):
+        sess = Session()
+        b1 = B()
+        b1.class_name = 'c'
+        sess.add(b1)
+        sess.commit()
+        sess.close()
+        assert isinstance(sess.query(B).first(), C)
     
 class CascadeTest(_base.MappedTest):
     """that cascades on polymorphic relationships continue