]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
limit joinedload exclusion rules to immediate mapped columns
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 13 May 2023 16:32:31 +0000 (12:32 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 14 May 2023 13:50:45 +0000 (09:50 -0400)
Fixed issue where using additional relationship criteria with the
:func:`_orm.joinedload` loader option, where the additional criteria itself
contained correlated subqueries that referred to the joined entities and
therefore also required "adaption" to aliased entities, would be excluded
from this adaption, producing the wrong ON clause for the joinedload.

Fixes: #9779
Change-Id: Idcfec3e760057fbf6a09c10ad67a0bb4bf70f03a

doc/build/changelog/unreleased_20/9779.rst [new file with mode: 0644]
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/sql/annotation.py
lib/sqlalchemy/sql/util.py
test/orm/inheritance/test_assorted_poly.py

diff --git a/doc/build/changelog/unreleased_20/9779.rst b/doc/build/changelog/unreleased_20/9779.rst
new file mode 100644 (file)
index 0000000..ab417b2
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9779
+
+    Fixed issue where using additional relationship criteria with the
+    :func:`_orm.joinedload` loader option, where the additional criteria itself
+    contained correlated subqueries that referred to the joined entities and
+    therefore also required "adaption" to aliased entities, would be excluded
+    from this adaption, producing the wrong ON clause for the joinedload.
index 8d9f3c644ac65f05ee45985c246e8d981cd72bef..e5a6b9afaa4adb8807bbe1a8f5a2e19b12dc2d28 100644 (file)
@@ -80,6 +80,7 @@ from ..sql._typing import _ColumnExpressionArgument
 from ..sql._typing import _HasClauseElement
 from ..sql.elements import ColumnClause
 from ..sql.elements import ColumnElement
+from ..sql.util import _deep_annotate
 from ..sql.util import _deep_deannotate
 from ..sql.util import _shallow_annotate
 from ..sql.util import adapt_criterion_to_null
@@ -115,6 +116,7 @@ if typing.TYPE_CHECKING:
     from ..sql._typing import _EquivalentColumnMap
     from ..sql._typing import _InfoType
     from ..sql.annotation import _AnnotationDict
+    from ..sql.annotation import SupportsAnnotations
     from ..sql.elements import BinaryExpression
     from ..sql.elements import BindParameter
     from ..sql.elements import ClauseElement
@@ -3284,6 +3286,38 @@ class JoinCondition:
                 primaryjoin = primaryjoin & single_crit
 
         if extra_criteria:
+
+            def mark_unrelated_columns_as_ok_to_adapt(
+                elem: SupportsAnnotations, annotations: _AnnotationDict
+            ) -> SupportsAnnotations:
+                """note unrelated columns in the "extra criteria" as OK
+                to adapt, even though they are not part of our "local"
+                or "remote" side.
+
+                see #9779 for this case
+
+                """
+
+                parentmapper_for_element = elem._annotations.get(
+                    "parentmapper", None
+                )
+                if (
+                    parentmapper_for_element is not self.prop.parent
+                    and parentmapper_for_element is not self.prop.mapper
+                ):
+                    return elem._annotate(annotations)
+                else:
+                    return elem
+
+            extra_criteria = tuple(
+                _deep_annotate(
+                    elem,
+                    {"ok_to_adapt_in_join_condition": True},
+                    annotate_callable=mark_unrelated_columns_as_ok_to_adapt,
+                )
+                for elem in extra_criteria
+            )
+
             if secondaryjoin is not None:
                 secondaryjoin = secondaryjoin & sql.and_(*extra_criteria)
             else:
@@ -3409,7 +3443,10 @@ class _ColInAnnotations:
         self.name = name
 
     def __call__(self, c: ClauseElement) -> bool:
-        return self.name in c._annotations
+        return (
+            self.name in c._annotations
+            or "ok_to_adapt_in_join_condition" in c._annotations
+        )
 
 
 class Relationship(  # type: ignore
index 7487e074ccccc0b0e74f976084513e4a9f4554e3..e6dee7d17e3c680cabda49e864993b2915ba4ddf 100644 (file)
@@ -406,8 +406,12 @@ def _deep_annotate(
     element: _SA,
     annotations: _AnnotationDict,
     exclude: Optional[Sequence[SupportsAnnotations]] = None,
+    *,
     detect_subquery_cols: bool = False,
     ind_cols_on_fromclause: bool = False,
+    annotate_callable: Optional[
+        Callable[[SupportsAnnotations, _AnnotationDict], SupportsAnnotations]
+    ] = None,
 ) -> _SA:
     """Deep copy the given ClauseElement, annotating each element
     with the given annotations dictionary.
@@ -446,9 +450,13 @@ def _deep_annotate(
             newelem = elem._clone(clone=clone, **kw)
         elif annotations != elem._annotations:
             if detect_subquery_cols and elem._is_immutable:
-                newelem = elem._clone(clone=clone, **kw)._annotate(annotations)
+                to_annotate = elem._clone(clone=clone, **kw)
             else:
-                newelem = elem._annotate(annotations)
+                to_annotate = elem
+            if annotate_callable:
+                newelem = annotate_callable(to_annotate, annotations)
+            else:
+                newelem = to_annotate._annotate(annotations)
         else:
             newelem = elem
 
index 0a50197a0d4a4493ee5dbed5b2d7f239dd878644..18caf5de4b5edec6bb73e4f9829c00ce1b043509 100644 (file)
@@ -1343,6 +1343,7 @@ class ColumnAdapter(ClauseAdapter):
     def traverse(
         self, obj: Optional[ExternallyTraversible]
     ) -> Optional[ExternallyTraversible]:
+
         return self.columns[obj]
 
     def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
index a40a9ae742c6b438f8f92e3469871033975b4396..28480c89ed559258e090878c230e848f2ad7cd3a 100644 (file)
@@ -7,6 +7,7 @@ from __future__ import annotations
 
 from typing import Optional
 
+from sqlalchemy import and_
 from sqlalchemy import exists
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
@@ -37,6 +38,7 @@ from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import AssertsExecutionResults
 from sqlalchemy.testing import config
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.fixtures import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
@@ -2765,3 +2767,169 @@ class PolyIntoSelfReferentialTest(
                     assert False
 
         self._run_load(opt)
+
+
+class AdaptExistsSubqTest(fixtures.DeclarativeMappedTest):
+    """test for #9777"""
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Discriminator(Base):
+            __tablename__ = "discriminator"
+            id = Column(Integer, primary_key=True, autoincrement=False)
+            value = Column(String(50))
+
+        class Entity(Base):
+            __tablename__ = "entity"
+            __mapper_args__ = {"polymorphic_on": "type"}
+
+            id = Column(Integer, primary_key=True, autoincrement=False)
+            type = Column(String(50))
+
+            discriminator_id = Column(
+                ForeignKey("discriminator.id"), nullable=False
+            )
+            discriminator = relationship(
+                "Discriminator", foreign_keys=discriminator_id
+            )
+
+        class Parent(Entity):
+            __tablename__ = "parent"
+            __mapper_args__ = {"polymorphic_identity": "parent"}
+
+            id = Column(Integer, ForeignKey("entity.id"), primary_key=True)
+            some_data = Column(String(30))
+
+        class Child(Entity):
+            __tablename__ = "child"
+            __mapper_args__ = {"polymorphic_identity": "child"}
+
+            id = Column(Integer, ForeignKey("entity.id"), primary_key=True)
+
+            some_data = Column(String(30))
+            parent_id = Column(ForeignKey("parent.id"), nullable=False)
+            parent = relationship(
+                "Parent",
+                foreign_keys=parent_id,
+                backref="children",
+            )
+
+    @classmethod
+    def insert_data(cls, connection):
+        Parent, Child, Discriminator = cls.classes(
+            "Parent", "Child", "Discriminator"
+        )
+
+        with Session(connection) as sess:
+            discriminator_zero = Discriminator(id=1, value="zero")
+            discriminator_one = Discriminator(id=2, value="one")
+            discriminator_two = Discriminator(id=3, value="two")
+
+            parent = Parent(id=1, discriminator=discriminator_zero)
+            child_1 = Child(
+                id=2,
+                discriminator=discriminator_one,
+                parent=parent,
+                some_data="c1data",
+            )
+            child_2 = Child(
+                id=3,
+                discriminator=discriminator_two,
+                parent=parent,
+                some_data="c2data",
+            )
+            sess.add_all([parent, child_1, child_2])
+            sess.commit()
+
+    def test_explicit_aliasing(self):
+        Parent, Child, Discriminator = self.classes(
+            "Parent", "Child", "Discriminator"
+        )
+
+        parent_id = 1
+        discriminator_one_id = 2
+
+        session = fixture_session()
+        c_alias = aliased(Child, flat=True)
+        retrieved = (
+            session.query(Parent)
+            .filter_by(id=parent_id)
+            .outerjoin(
+                Parent.children.of_type(c_alias).and_(
+                    c_alias.discriminator.has(
+                        and_(
+                            Discriminator.id == discriminator_one_id,
+                            c_alias.some_data == "c1data",
+                        )
+                    )
+                )
+            )
+            .options(contains_eager(Parent.children.of_type(c_alias)))
+            .populate_existing()
+            .one()
+        )
+        eq_(len(retrieved.children), 1)
+
+    def test_implicit_aliasing(self):
+        Parent, Child, Discriminator = self.classes(
+            "Parent", "Child", "Discriminator"
+        )
+
+        parent_id = 1
+        discriminator_one_id = 2
+
+        session = fixture_session()
+        q = (
+            session.query(Parent)
+            .filter_by(id=parent_id)
+            .outerjoin(
+                Parent.children.and_(
+                    Child.discriminator.has(
+                        and_(
+                            Discriminator.id == discriminator_one_id,
+                            Child.some_data == "c1data",
+                        )
+                    )
+                )
+            )
+            .options(contains_eager(Parent.children))
+            .populate_existing()
+        )
+
+        with expect_warnings("An alias is being generated automatically"):
+            retrieved = q.one()
+
+        eq_(len(retrieved.children), 1)
+
+    @testing.combinations(joinedload, selectinload, argnames="loader")
+    def test_eager_loaders(self, loader):
+        Parent, Child, Discriminator = self.classes(
+            "Parent", "Child", "Discriminator"
+        )
+
+        parent_id = 1
+        discriminator_one_id = 2
+
+        session = fixture_session()
+        retrieved = (
+            session.query(Parent)
+            .filter_by(id=parent_id)
+            .options(
+                loader(
+                    Parent.children.and_(
+                        Child.discriminator.has(
+                            and_(
+                                Discriminator.id == discriminator_one_id,
+                                Child.some_data == "c1data",
+                            )
+                        )
+                    )
+                )
+            )
+            .populate_existing()
+            .one()
+        )
+
+        eq_(len(retrieved.children), 1)