]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Raise for NULL discriminator and pk is present
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 Sep 2019 13:56:41 +0000 (09:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 4 Sep 2019 02:20:15 +0000 (22:20 -0400)
An exception is now raised if the ORM loads a row for a polymorphic
instance that has a primary key but the discriminator column is NULL, as
discriminator columns should not be null.

Fixes: #4836
Change-Id: Ice1a853a7dd7687c58079b9933f145b90d314236

doc/build/changelog/unreleased_14/4836.rst [new file with mode: 0644]
lib/sqlalchemy/orm/loading.py
test/orm/inheritance/test_basic.py
test/orm/inheritance/test_relationship.py

diff --git a/doc/build/changelog/unreleased_14/4836.rst b/doc/build/changelog/unreleased_14/4836.rst
new file mode 100644 (file)
index 0000000..5a4d64d
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 4836
+
+    An exception is now raised if the ORM loads a row for a polymorphic
+    instance that has a primary key but the discriminator column is NULL, as
+    discriminator columns should not be null.
+
+
index f07067d17d386f8f3de9a80e7f57ff0d2f55cae8..94a9b8d2205c04c72f567c325edfc7a0c0c62942 100644 (file)
@@ -633,6 +633,20 @@ def _instance_processor(
     if mapper.polymorphic_map and not _polymorphic_from and not refresh_state:
         # if we are doing polymorphic, dispatch to a different _instance()
         # method specific to the subclass mapper
+
+        def ensure_no_pk(row):
+            identitykey = (
+                identity_class,
+                tuple([row[column] for column in pk_cols]),
+                identity_token,
+            )
+            if not is_not_primary_key(identitykey[1]):
+                raise sa_exc.InvalidRequestError(
+                    "Row with identity key %s can't be loaded into an "
+                    "object; the polymorphic discriminator column '%s' is "
+                    "NULL" % (identitykey, polymorphic_discriminator)
+                )
+
         _instance = _decorate_polymorphic_switch(
             _instance,
             context,
@@ -641,6 +655,7 @@ def _instance_processor(
             path,
             polymorphic_discriminator,
             adapter,
+            ensure_no_pk,
         )
 
     return _instance
@@ -804,6 +819,7 @@ def _decorate_polymorphic_switch(
     path,
     polymorphic_discriminator,
     adapter,
+    ensure_no_pk,
 ):
     if polymorphic_discriminator is not None:
         polymorphic_on = polymorphic_discriminator
@@ -843,7 +859,11 @@ def _decorate_polymorphic_switch(
             _instance = polymorphic_instances[discriminator]
             if _instance:
                 return _instance(row)
-        return instance_fn(row)
+            else:
+                return instance_fn(row)
+        else:
+            ensure_no_pk(row)
+            return None
 
     return polymorphic_instance
 
index 9fce5891b3d394087d54810fe00d8f77e7e49425..57320d808f862dc9aaf2690bdc957e3cc18c6789 100644 (file)
@@ -3258,6 +3258,100 @@ class PolymorphicUnionTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
 
+class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest):
+
+    run_setup_mappers = "once"
+    __dialect__ = "default"
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Parent(fixtures.ComparableEntity, Base):
+            __tablename__ = "parent"
+            id = Column(Integer, primary_key=True)
+
+        class A(fixtures.ComparableEntity, Base):
+            __tablename__ = "a"
+            id = Column(Integer, primary_key=True)
+            parent_id = Column(ForeignKey("parent.id"))
+            type = Column(String(50))
+            __mapper_args__ = {
+                "polymorphic_on": type,
+                "polymorphic_identity": "a",
+            }
+
+        class B(A):
+            __tablename__ = "b"
+            id = Column(ForeignKey("a.id"), primary_key=True)
+            __mapper_args__ = {"polymorphic_identity": "b"}
+
+    @classmethod
+    def insert_data(cls):
+        Parent, A, B = cls.classes("Parent", "A", "B")
+        s = Session()
+
+        p1 = Parent(id=1)
+        p2 = Parent(id=2)
+        s.add_all([p1, p2])
+        s.flush()
+
+        s.add_all(
+            [
+                A(id=1, parent_id=1),
+                B(id=2, parent_id=1),
+                A(id=3, parent_id=1),
+                B(id=4, parent_id=1),
+            ]
+        )
+        s.flush()
+
+        s.query(A).filter(A.id.in_([3, 4])).update(
+            {A.type: None}, synchronize_session=False
+        )
+        s.commit()
+
+    def test_pk_is_null(self):
+        Parent, A = self.classes("Parent", "A")
+
+        sess = Session()
+        q = (
+            sess.query(Parent, A)
+            .select_from(Parent)
+            .outerjoin(A)
+            .filter(Parent.id == 2)
+        )
+        row = q.all()[0]
+
+        eq_(row, (Parent(id=2), None))
+
+    def test_pk_not_null_discriminator_null_from_base(self):
+        A, = self.classes("A")
+
+        sess = Session()
+        q = sess.query(A).filter(A.id == 3)
+        assert_raises_message(
+            sa_exc.InvalidRequestError,
+            r"Row with identity key \(<class '.*A'>, \(3,\), None\) can't be "
+            "loaded into an object; the polymorphic discriminator "
+            "column 'a.type' is NULL",
+            q.all,
+        )
+
+    def test_pk_not_null_discriminator_null_from_sub(self):
+        B, = self.classes("B")
+
+        sess = Session()
+        q = sess.query(B).filter(B.id == 4)
+        assert_raises_message(
+            sa_exc.InvalidRequestError,
+            r"Row with identity key \(<class '.*A'>, \(4,\), None\) can't be "
+            "loaded into an object; the polymorphic discriminator "
+            "column 'a.type' is NULL",
+            q.all,
+        )
+
+
 class NameConflictTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
index 93396c73b02cfb9854ab49b4555e81617797f6e8..a96c5ef04af4b16af8dd1876f8e9fed69c0402ca 100644 (file)
@@ -1821,6 +1821,7 @@ class JoinedloadOverWPolyAliased(
 
             __mapper_args__ = {
                 "polymorphic_on": type,
+                "polymorphic_identity": "parent",
                 "with_polymorphic": ("*", None),
             }