from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import polymorphic_union
+from sqlalchemy.orm import query_expression
from sqlalchemy.orm import relationship
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import subqueryload
+from sqlalchemy.orm import with_expression
from sqlalchemy.orm import with_polymorphic
from sqlalchemy.orm.interfaces import MANYTOONE
from sqlalchemy.testing import AssertsCompiledSQL
# as well as the collection eagerly loaded
assert obj.bs
+
+
+class PolymorphicQueryExpressionTest(fixtures.DeclarativeMappedTest):
+ """Test for issue #12631 - with_expression() can apply the
+ polymorphic discriminator"""
+
+ @classmethod
+ def setup_classes(cls):
+ Base = cls.DeclarativeBasic
+
+ class Service(Base):
+ __tablename__ = "service"
+ id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str]
+ type: Mapped[str]
+
+ class BaseRequest(Base):
+ __tablename__ = "base_request"
+ id: Mapped[int] = mapped_column(primary_key=True)
+ service_id: Mapped[int] = mapped_column(ForeignKey(Service.id))
+ service: Mapped[Service] = relationship()
+ type: Mapped[str] = query_expression(
+ select(Service.type)
+ .where(Service.id == service_id)
+ .correlate_except(Service)
+ .scalar_subquery()
+ )
+ __mapper_args__ = {
+ "polymorphic_on": type,
+ "polymorphic_abstract": True,
+ }
+
+ class RequestA(BaseRequest):
+ __tablename__ = "request_a"
+ id: Mapped[int] = mapped_column(
+ ForeignKey(BaseRequest.id), primary_key=True
+ )
+ __mapper_args__ = {"polymorphic_identity": "SERVICE_A"}
+
+ class RequestB(BaseRequest):
+ __tablename__ = "request_b"
+ id: Mapped[int] = mapped_column(
+ ForeignKey(BaseRequest.id), primary_key=True
+ )
+ __mapper_args__ = {"polymorphic_identity": "SERVICE_B"}
+
+ @classmethod
+ def insert_data(cls, connection):
+ Service, RequestA, RequestB = cls.classes(
+ "Service", "RequestA", "RequestB"
+ )
+
+ with Session(connection) as sess:
+ sess.add(
+ RequestA(service=Service(name="test-a", type="SERVICE_A"))
+ )
+ sess.add(
+ RequestB(service=Service(name="test-b", type="SERVICE_B"))
+ )
+ sess.commit()
+
+ def test_with_expression_polymorphic_discriminator(self):
+ """Test that with_expression() can provide the polymorphic
+ discriminator value."""
+ BaseRequest, Service = self.classes("BaseRequest", "Service")
+
+ session = fixture_session()
+
+ result = list(
+ session.scalars(
+ select(BaseRequest)
+ .join(BaseRequest.service)
+ .options(
+ contains_eager(BaseRequest.service),
+ with_expression(BaseRequest.type, Service.type),
+ )
+ )
+ )
+
+ # Should get correct polymorphic subclasses
+ eq_(len(result), 2)
+ assert any(r.__class__.__name__ == "RequestA" for r in result)
+ assert any(r.__class__.__name__ == "RequestB" for r in result)
+
+ def test_default_query_expression_loads(self):
+ """Test that the default query_expression loads correctly
+ even though it's the polymorphic discriminator."""
+ BaseRequest, Service = self.classes("BaseRequest", "Service")
+
+ session = fixture_session()
+
+ # Query without with_expression() - should use the default
+ # scalar subquery expression
+ result = list(session.scalars(select(BaseRequest)))
+
+ # Should get correct polymorphic subclasses using default expression
+ eq_(len(result), 2)
+ assert any(r.__class__.__name__ == "RequestA" for r in result)
+ assert any(r.__class__.__name__ == "RequestB" for r in result)
+
+ # The type attribute should be loaded with the correct values
+ for r in result:
+ if r.__class__.__name__ == "RequestA":
+ eq_(r.type, "SERVICE_A")
+ else:
+ eq_(r.type, "SERVICE_B")
+
+ def test_with_polymorphic_workaround(self):
+ """Test that the workaround using explicit with_polymorphic
+ continues to work."""
+ BaseRequest, Service = self.classes("BaseRequest", "Service")
+
+ session = fixture_session()
+
+ bwp = with_polymorphic(
+ BaseRequest, [], polymorphic_on=Service.__table__.c.type
+ )
+
+ result = list(
+ session.scalars(
+ select(bwp)
+ .join(bwp.service)
+ .options(
+ contains_eager(bwp.service),
+ with_expression(bwp.type, Service.type),
+ )
+ )
+ )
+
+ # Should get correct polymorphic subclasses
+ eq_(len(result), 2)
+ assert any(r.__class__.__name__ == "RequestA" for r in result)
+ assert any(r.__class__.__name__ == "RequestB" for r in result)