]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
handle polymorphic_discriminator in query_expression
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Dec 2025 16:15:46 +0000 (11:15 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Dec 2025 19:38:18 +0000 (14:38 -0500)
Added support for using :func:`_orm.with_expression` to populate a
:func:`_orm.query_expression` attribute that is also configured as the
``polymorphic_on`` discriminator column. The ORM now detects when a query
expression column is serving as the polymorphic discriminator and updates
it to use the column provided via :func:`_orm.with_expression`, enabling
polymorphic loading to work correctly in this scenario. This allows for
patterns such as where the discriminator value is computed from a related
table.

Fixes: #12631
Change-Id: I20baf4cddc5a19664bf73764f9371b187686af68

doc/build/changelog/unreleased_21/12631.rst [new file with mode: 0644]
lib/sqlalchemy/orm/strategies.py
test/orm/inheritance/test_assorted_poly.py

diff --git a/doc/build/changelog/unreleased_21/12631.rst b/doc/build/changelog/unreleased_21/12631.rst
new file mode 100644 (file)
index 0000000..abdf6b7
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 12631
+
+    Added support for using :func:`_orm.with_expression` to populate a
+    :func:`_orm.query_expression` attribute that is also configured as the
+    ``polymorphic_on`` discriminator column. The ORM now detects when a query
+    expression column is serving as the polymorphic discriminator and updates
+    it to use the column provided via :func:`_orm.with_expression`, enabling
+    polymorphic loading to work correctly in this scenario. This allows for
+    patterns such as where the discriminator value is computed from a related
+    table.
index 6a71316646916b5ed55cd083007a47dea834ce51..e636ef7dd518d96d28b3703d07c9bab11e709189 100644 (file)
@@ -338,6 +338,17 @@ class _ExpressionColumnLoader(_ColumnLoader):
 
         memoized_populators[self.parent_property] = fetch
 
+        # if the column being loaded is the polymorphic discriminator,
+        # and we have a with_expression() providing the actual column,
+        # update the query_entity to use the actual column instead of
+        # the default expression
+        if (
+            query_entity._polymorphic_discriminator is self.columns[0]
+            and loadopt
+            and loadopt._extra_criteria
+        ):
+            query_entity._polymorphic_discriminator = columns[0]
+
     def create_row_processor(
         self,
         context,
index ea8be8d3769da33f9ca194120c67b6c904aec396..f31905092b04c3b3127e76f1280d31e8c1414896 100644 (file)
@@ -28,11 +28,13 @@ from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import polymorphic_union
+from sqlalchemy.orm import query_expression
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.orm import subqueryload
+from sqlalchemy.orm import with_expression
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.orm.interfaces import MANYTOONE
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -3326,3 +3328,136 @@ class SubclassWithPolyEagerLoadTest(fixtures.DeclarativeMappedTest):
 
                 # as well as the collection eagerly loaded
                 assert obj.bs
+
+
+class PolymorphicQueryExpressionTest(fixtures.DeclarativeMappedTest):
+    """Test for issue #12631 - with_expression() can apply the
+    polymorphic discriminator"""
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Service(Base):
+            __tablename__ = "service"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            name: Mapped[str]
+            type: Mapped[str]
+
+        class BaseRequest(Base):
+            __tablename__ = "base_request"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            service_id: Mapped[int] = mapped_column(ForeignKey(Service.id))
+            service: Mapped[Service] = relationship()
+            type: Mapped[str] = query_expression(
+                select(Service.type)
+                .where(Service.id == service_id)
+                .correlate_except(Service)
+                .scalar_subquery()
+            )
+            __mapper_args__ = {
+                "polymorphic_on": type,
+                "polymorphic_abstract": True,
+            }
+
+        class RequestA(BaseRequest):
+            __tablename__ = "request_a"
+            id: Mapped[int] = mapped_column(
+                ForeignKey(BaseRequest.id), primary_key=True
+            )
+            __mapper_args__ = {"polymorphic_identity": "SERVICE_A"}
+
+        class RequestB(BaseRequest):
+            __tablename__ = "request_b"
+            id: Mapped[int] = mapped_column(
+                ForeignKey(BaseRequest.id), primary_key=True
+            )
+            __mapper_args__ = {"polymorphic_identity": "SERVICE_B"}
+
+    @classmethod
+    def insert_data(cls, connection):
+        Service, RequestA, RequestB = cls.classes(
+            "Service", "RequestA", "RequestB"
+        )
+
+        with Session(connection) as sess:
+            sess.add(
+                RequestA(service=Service(name="test-a", type="SERVICE_A"))
+            )
+            sess.add(
+                RequestB(service=Service(name="test-b", type="SERVICE_B"))
+            )
+            sess.commit()
+
+    def test_with_expression_polymorphic_discriminator(self):
+        """Test that with_expression() can provide the polymorphic
+        discriminator value."""
+        BaseRequest, Service = self.classes("BaseRequest", "Service")
+
+        session = fixture_session()
+
+        result = list(
+            session.scalars(
+                select(BaseRequest)
+                .join(BaseRequest.service)
+                .options(
+                    contains_eager(BaseRequest.service),
+                    with_expression(BaseRequest.type, Service.type),
+                )
+            )
+        )
+
+        # Should get correct polymorphic subclasses
+        eq_(len(result), 2)
+        assert any(r.__class__.__name__ == "RequestA" for r in result)
+        assert any(r.__class__.__name__ == "RequestB" for r in result)
+
+    def test_default_query_expression_loads(self):
+        """Test that the default query_expression loads correctly
+        even though it's the polymorphic discriminator."""
+        BaseRequest, Service = self.classes("BaseRequest", "Service")
+
+        session = fixture_session()
+
+        # Query without with_expression() - should use the default
+        # scalar subquery expression
+        result = list(session.scalars(select(BaseRequest)))
+
+        # Should get correct polymorphic subclasses using default expression
+        eq_(len(result), 2)
+        assert any(r.__class__.__name__ == "RequestA" for r in result)
+        assert any(r.__class__.__name__ == "RequestB" for r in result)
+
+        # The type attribute should be loaded with the correct values
+        for r in result:
+            if r.__class__.__name__ == "RequestA":
+                eq_(r.type, "SERVICE_A")
+            else:
+                eq_(r.type, "SERVICE_B")
+
+    def test_with_polymorphic_workaround(self):
+        """Test that the workaround using explicit with_polymorphic
+        continues to work."""
+        BaseRequest, Service = self.classes("BaseRequest", "Service")
+
+        session = fixture_session()
+
+        bwp = with_polymorphic(
+            BaseRequest, [], polymorphic_on=Service.__table__.c.type
+        )
+
+        result = list(
+            session.scalars(
+                select(bwp)
+                .join(bwp.service)
+                .options(
+                    contains_eager(bwp.service),
+                    with_expression(bwp.type, Service.type),
+                )
+            )
+        )
+
+        # Should get correct polymorphic subclasses
+        eq_(len(result), 2)
+        assert any(r.__class__.__name__ == "RequestA" for r in result)
+        assert any(r.__class__.__name__ == "RequestB" for r in result)