]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
A warning is emitted when trying to flush an object of an inherited
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Jun 2013 15:09:37 +0000 (11:09 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Jun 2013 15:11:35 +0000 (11:11 -0400)
mapped class where the polymorphic discriminator has been assigned
to a value that is invalid for the class.   [ticket:2750]

doc/build/changelog/changelog_08.rst
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
test/orm/inheritance/test_basic.py

index 442b7a2d9c4e4e44e82a0d55f5c33ee592bd85d6..c0e430ad66fd8e7539aa7565d8350f68b880d47e 100644 (file)
@@ -6,6 +6,14 @@
 .. changelog::
     :version: 0.8.2
 
+    .. change::
+        :tags: bug, orm
+        :tickets: 2750
+
+        A warning is emitted when trying to flush an object of an inherited
+        class where the polymorphic discriminator has been assigned
+        to a value that is invalid for the class.
+
     .. change::
         :tags: bug, postgresql
         :tickets: 2740
index c08d91b570cdc566357f17dab520a18b38306195..5d35d1ca84c4125c835d7304e04ca094fe5aa03a 100644 (file)
@@ -26,8 +26,8 @@ from . import instrumentation, attributes, \
 from .interfaces import MapperProperty, _InspectionAttr, _MappedAttribute
 
 from .util import _INSTRUMENTOR, _class_to_mapper, \
-     _state_mapper, class_mapper, \
-     PathRegistry
+        _state_mapper, class_mapper, \
+        PathRegistry, state_str
 import sys
 properties = util.importlater("sqlalchemy.orm", "properties")
 descriptor_props = util.importlater("sqlalchemy.orm", "descriptor_props")
@@ -1039,6 +1039,8 @@ class Mapper(_InspectionAttr):
                     if self.polymorphic_on is not None:
                         self._set_polymorphic_identity = \
                             mapper._set_polymorphic_identity
+                        self._validate_polymorphic_identity = \
+                            mapper._validate_polymorphic_identity
                     else:
                         self._set_polymorphic_identity = None
                     return
@@ -1049,10 +1051,39 @@ class Mapper(_InspectionAttr):
                 state.get_impl(polymorphic_key).set(state, dict_,
                         state.manager.mapper.polymorphic_identity, None)
 
+            def _validate_polymorphic_identity(mapper, state, dict_):
+                if dict_[polymorphic_key] not in \
+                    mapper._acceptable_polymorphic_identities:
+                    util.warn(
+                                "Flushing object %s with "
+                                "incompatible polymorphic identity %r; the "
+                                "object may not refresh and/or load correctly" % (
+                                        state_str(state),
+                                        dict_[polymorphic_key]
+                                    )
+                            )
+
             self._set_polymorphic_identity = _set_polymorphic_identity
+            self._validate_polymorphic_identity = _validate_polymorphic_identity
         else:
             self._set_polymorphic_identity = None
 
+
+    _validate_polymorphic_identity = None
+
+    @_memoized_configured_property
+    def _acceptable_polymorphic_identities(self):
+        identities = set()
+
+        stack = deque([self])
+        while stack:
+            item = stack.popleft()
+            if item.mapped_table is self.mapped_table:
+                identities.add(item.polymorphic_identity)
+                stack.extend(item._inheriting_mappers)
+
+        return identities
+
     def _adapt_inherited_property(self, key, prop, init):
         if not self.concrete:
             self._configure_property(key, prop, init=False, setparent=False)
index 5bc4739c1577e73ee619d00a01e949f75d20136c..46b0be947ab1108c5e639e4e58a5e08a0d7f90de 100644 (file)
@@ -150,6 +150,9 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
         else:
             mapper.dispatch.before_update(mapper, connection, state)
 
+        if mapper._validate_polymorphic_identity:
+            mapper._validate_polymorphic_identity(mapper, state, dict_)
+
         # detect if we have a "pending" instance (i.e. has
         # no instance_key attached to it), and another instance
         # with the same identity key already exists as persistent.
index bbfa543830163e033b807f593a7c0ec69d6bd3f5..f313a34ffd29e4af97952e3deb6e2bc5b59395af 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy import *
 from sqlalchemy import exc as sa_exc, util, event
 from sqlalchemy.orm import *
+from sqlalchemy.orm.util import instance_str
 from sqlalchemy.orm import exc as orm_exc, attributes
 from sqlalchemy.testing.assertsql import AllOf, CompiledSQL
 from sqlalchemy.sql import table, column
@@ -573,6 +574,8 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest):
             pass
         class C(B):
             pass
+        class D(B):
+            pass
 
         mapper(A, table_a,
                         polymorphic_on=table_a.c.class_name,
@@ -582,6 +585,8 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest):
                         polymorphic_identity='b')
         mapper(C, table_c, inherits=B,
                         polymorphic_identity='c')
+        mapper(D, inherits=B,
+                        polymorphic_identity='d')
 
     def test_poly_configured_immediate(self):
         A, C, B = (self.classes.A,
@@ -612,15 +617,30 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest):
         assert isinstance(sess.query(A).first(), C)
 
     def test_assignment(self):
-        C, B = self.classes.C, self.classes.B
+        D, B = self.classes.D, self.classes.B
 
         sess = Session()
         b1 = B()
-        b1.class_name = 'c'
+        b1.class_name = 'd'
         sess.add(b1)
         sess.commit()
         sess.close()
-        assert isinstance(sess.query(B).first(), C)
+        assert isinstance(sess.query(B).first(), D)
+
+    def test_invalid_assignment(self):
+        C, B = self.classes.C, self.classes.B
+
+        sess = Session()
+        c1 = C()
+        c1.class_name = 'b'
+        sess.add(c1)
+        assert_raises_message(
+            sa_exc.SAWarning,
+            "Flushing object %s with incompatible "
+            "polymorphic identity 'b'; the object may not "
+            "refresh and/or load correctly" % instance_str(c1),
+            sess.flush
+        )
 
 class CascadeTest(fixtures.MappedTest):
     """that cascades on polymorphic relationships continue