From: Mike Bayer Date: Thu, 18 Dec 2025 16:15:46 +0000 (-0500) Subject: handle polymorphic_discriminator in query_expression X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4a8d7f2dd7101265d8bb8dc90f18126df7a16ccd;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git handle polymorphic_discriminator in query_expression Added support for using :func:`_orm.with_expression` to populate a :func:`_orm.query_expression` attribute that is also configured as the ``polymorphic_on`` discriminator column. The ORM now detects when a query expression column is serving as the polymorphic discriminator and updates it to use the column provided via :func:`_orm.with_expression`, enabling polymorphic loading to work correctly in this scenario. This allows for patterns such as where the discriminator value is computed from a related table. Fixes: #12631 Change-Id: I20baf4cddc5a19664bf73764f9371b187686af68 --- diff --git a/doc/build/changelog/unreleased_21/12631.rst b/doc/build/changelog/unreleased_21/12631.rst new file mode 100644 index 0000000000..abdf6b7707 --- /dev/null +++ b/doc/build/changelog/unreleased_21/12631.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: usecase, orm + :tickets: 12631 + + Added support for using :func:`_orm.with_expression` to populate a + :func:`_orm.query_expression` attribute that is also configured as the + ``polymorphic_on`` discriminator column. The ORM now detects when a query + expression column is serving as the polymorphic discriminator and updates + it to use the column provided via :func:`_orm.with_expression`, enabling + polymorphic loading to work correctly in this scenario. This allows for + patterns such as where the discriminator value is computed from a related + table. diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 6a71316646..e636ef7dd5 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -338,6 +338,17 @@ class _ExpressionColumnLoader(_ColumnLoader): memoized_populators[self.parent_property] = fetch + # if the column being loaded is the polymorphic discriminator, + # and we have a with_expression() providing the actual column, + # update the query_entity to use the actual column instead of + # the default expression + if ( + query_entity._polymorphic_discriminator is self.columns[0] + and loadopt + and loadopt._extra_criteria + ): + query_entity._polymorphic_discriminator = columns[0] + def create_row_processor( self, context, diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index ea8be8d376..f31905092b 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -28,11 +28,13 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import polymorphic_union +from sqlalchemy.orm import query_expression from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import subqueryload +from sqlalchemy.orm import with_expression from sqlalchemy.orm import with_polymorphic from sqlalchemy.orm.interfaces import MANYTOONE from sqlalchemy.testing import AssertsCompiledSQL @@ -3326,3 +3328,136 @@ class SubclassWithPolyEagerLoadTest(fixtures.DeclarativeMappedTest): # as well as the collection eagerly loaded assert obj.bs + + +class PolymorphicQueryExpressionTest(fixtures.DeclarativeMappedTest): + """Test for issue #12631 - with_expression() can apply the + polymorphic discriminator""" + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class Service(Base): + __tablename__ = "service" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + type: Mapped[str] + + class BaseRequest(Base): + __tablename__ = "base_request" + id: Mapped[int] = mapped_column(primary_key=True) + service_id: Mapped[int] = mapped_column(ForeignKey(Service.id)) + service: Mapped[Service] = relationship() + type: Mapped[str] = query_expression( + select(Service.type) + .where(Service.id == service_id) + .correlate_except(Service) + .scalar_subquery() + ) + __mapper_args__ = { + "polymorphic_on": type, + "polymorphic_abstract": True, + } + + class RequestA(BaseRequest): + __tablename__ = "request_a" + id: Mapped[int] = mapped_column( + ForeignKey(BaseRequest.id), primary_key=True + ) + __mapper_args__ = {"polymorphic_identity": "SERVICE_A"} + + class RequestB(BaseRequest): + __tablename__ = "request_b" + id: Mapped[int] = mapped_column( + ForeignKey(BaseRequest.id), primary_key=True + ) + __mapper_args__ = {"polymorphic_identity": "SERVICE_B"} + + @classmethod + def insert_data(cls, connection): + Service, RequestA, RequestB = cls.classes( + "Service", "RequestA", "RequestB" + ) + + with Session(connection) as sess: + sess.add( + RequestA(service=Service(name="test-a", type="SERVICE_A")) + ) + sess.add( + RequestB(service=Service(name="test-b", type="SERVICE_B")) + ) + sess.commit() + + def test_with_expression_polymorphic_discriminator(self): + """Test that with_expression() can provide the polymorphic + discriminator value.""" + BaseRequest, Service = self.classes("BaseRequest", "Service") + + session = fixture_session() + + result = list( + session.scalars( + select(BaseRequest) + .join(BaseRequest.service) + .options( + contains_eager(BaseRequest.service), + with_expression(BaseRequest.type, Service.type), + ) + ) + ) + + # Should get correct polymorphic subclasses + eq_(len(result), 2) + assert any(r.__class__.__name__ == "RequestA" for r in result) + assert any(r.__class__.__name__ == "RequestB" for r in result) + + def test_default_query_expression_loads(self): + """Test that the default query_expression loads correctly + even though it's the polymorphic discriminator.""" + BaseRequest, Service = self.classes("BaseRequest", "Service") + + session = fixture_session() + + # Query without with_expression() - should use the default + # scalar subquery expression + result = list(session.scalars(select(BaseRequest))) + + # Should get correct polymorphic subclasses using default expression + eq_(len(result), 2) + assert any(r.__class__.__name__ == "RequestA" for r in result) + assert any(r.__class__.__name__ == "RequestB" for r in result) + + # The type attribute should be loaded with the correct values + for r in result: + if r.__class__.__name__ == "RequestA": + eq_(r.type, "SERVICE_A") + else: + eq_(r.type, "SERVICE_B") + + def test_with_polymorphic_workaround(self): + """Test that the workaround using explicit with_polymorphic + continues to work.""" + BaseRequest, Service = self.classes("BaseRequest", "Service") + + session = fixture_session() + + bwp = with_polymorphic( + BaseRequest, [], polymorphic_on=Service.__table__.c.type + ) + + result = list( + session.scalars( + select(bwp) + .join(bwp.service) + .options( + contains_eager(bwp.service), + with_expression(bwp.type, Service.type), + ) + ) + ) + + # Should get correct polymorphic subclasses + eq_(len(result), 2) + assert any(r.__class__.__name__ == "RequestA" for r in result) + assert any(r.__class__.__name__ == "RequestB" for r in result)