]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
safe annotate QueryableAttribute inside of join() condition
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 10 Aug 2023 22:26:45 +0000 (18:26 -0400)
committermike bayer <mike_mp@zzzcomputing.com>
Fri, 11 Aug 2023 22:23:04 +0000 (22:23 +0000)
Fixed fundamental issue which prevented some forms of ORM "annotations"
from taking place for subqueries which made use of :meth:`_sql.Select.join`
against a relationship target.  These annotations are used whenever a
subquery is used in special situations such as within
:meth:`_orm.PropComparator.and_` and other ORM-specific scenarios.

Fixes: #10223
Change-Id: I40f04265a6caa0fdcbc9f1b121a35561ab4b1fcf

doc/build/changelog/unreleased_14/10223.rst [new file with mode: 0644]
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/sql/annotation.py
test/orm/test_rel_fn.py
test/orm/test_relationship_criteria.py

diff --git a/doc/build/changelog/unreleased_14/10223.rst b/doc/build/changelog/unreleased_14/10223.rst
new file mode 100644 (file)
index 0000000..7c74424
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 10223
+    :versions: 2.0.20
+
+    Fixed fundamental issue which prevented some forms of ORM "annotations"
+    from taking place for subqueries which made use of :meth:`_sql.Select.join`
+    against a relationship target.  These annotations are used whenever a
+    subquery is used in special situations such as within
+    :meth:`_orm.PropComparator.and_` and other ORM-specific scenarios.
index e2ba6989266cd6e55747ade73962cf42f7e84fd9..d3a8da042a4a953c2fec8ad4606e48798d4ff4e4 100644 (file)
@@ -78,6 +78,7 @@ from ..sql import roles
 from ..sql import visitors
 from ..sql._typing import _ColumnExpressionArgument
 from ..sql._typing import _HasClauseElement
+from ..sql.annotation import _safe_annotate
 from ..sql.elements import ColumnClause
 from ..sql.elements import ColumnElement
 from ..sql.util import _deep_annotate
@@ -3297,7 +3298,7 @@ class JoinCondition:
                     parentmapper_for_element is not self.prop.parent
                     and parentmapper_for_element is not self.prop.mapper
                 ):
-                    return elem._annotate(annotations)
+                    return _safe_annotate(elem, annotations)
                 else:
                     return elem
 
index 016608a38168b1a9a48e1613e30bb0cd601c3df3..4ccde591a9ac92b979b97578c19818e29daf0ec0 100644 (file)
@@ -402,6 +402,18 @@ annotated_classes: Dict[
 _SA = TypeVar("_SA", bound="SupportsAnnotations")
 
 
+def _safe_annotate(to_annotate: _SA, annotations: _AnnotationDict) -> _SA:
+    try:
+        _annotate = to_annotate._annotate
+    except AttributeError:
+        # skip objects that don't actually have an `_annotate`
+        # attribute, namely QueryableAttribute inside of a join
+        # condition
+        return to_annotate
+    else:
+        return _annotate(annotations)
+
+
 def _deep_annotate(
     element: _SA,
     annotations: _AnnotationDict,
@@ -455,7 +467,7 @@ def _deep_annotate(
             if annotate_callable:
                 newelem = annotate_callable(to_annotate, annotations)
             else:
-                newelem = to_annotate._annotate(annotations)
+                newelem = _safe_annotate(to_annotate, annotations)
         else:
             newelem = elem
 
index 5b2a15c13a541901e7ed23cfa64dc492b6c05e1b..83ffff3c91bdbec43660a1203b062614a7f3f72a 100644 (file)
@@ -1309,6 +1309,28 @@ class LazyClauseTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
 
 
 class DeannotateCorrectlyTest(fixtures.TestBase):
+    def test_annotate_orm_join(self):
+        """test for #10223"""
+        from sqlalchemy.orm import declarative_base
+
+        Base = declarative_base()
+
+        class A(Base):
+            __tablename__ = "a"
+            id = Column(Integer, primary_key=True)
+            bs = relationship("B")
+
+        class B(Base):
+            __tablename__ = "b"
+            id = Column(Integer, primary_key=True)
+            a_id = Column(ForeignKey(A.id))
+
+        stmt = select(A).join(A.bs)
+
+        from sqlalchemy.sql import util
+
+        util._deep_annotate(stmt, {"foo": "bar"})
+
     def test_pj_deannotates(self):
         from sqlalchemy.orm import declarative_base
 
index d03b79e8920c623cd2f2a49a6da18c51e53f9f1c..31423e5f4a676dccd48e80a4bc03b47664ba9bbf 100644 (file)
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 import datetime
 import random
+from typing import List
 
 from sqlalchemy import Column
 from sqlalchemy import DateTime
@@ -15,15 +18,19 @@ from sqlalchemy import orm
 from sqlalchemy import select
 from sqlalchemy import sql
 from sqlalchemy import String
+from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import union
 from sqlalchemy import update
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import column_property
+from sqlalchemy.orm import contains_eager
 from sqlalchemy.orm import defer
 from sqlalchemy.orm import join as orm_join
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import lazyload
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import registry
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
@@ -33,6 +40,7 @@ from sqlalchemy.orm import with_loader_criteria
 from sqlalchemy.orm.decl_api import declared_attr
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import expect_raises
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.fixtures import fixture_session
@@ -2365,3 +2373,176 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
             "JOIN items AS items_1 ON items_1.id = order_items_1.item_id "
             "AND items_1.description != :description_1",
         )
+
+
+class SubqueryCriteriaTest(fixtures.DeclarativeMappedTest):
+    """test #10223"""
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Temperature(Base):
+            __tablename__ = "temperature"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            pointless_flag: Mapped[bool]
+
+        class Color(Base):
+            __tablename__ = "color"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            name: Mapped[str] = mapped_column(unique=True)
+            temperature_id: Mapped[int] = mapped_column(
+                ForeignKey("temperature.id")
+            )
+            temperature: Mapped[Temperature] = relationship()
+
+        room_connections = Table(
+            "room_connections",
+            Base.metadata,
+            Column(
+                "room_a_id",
+                Integer,
+                # mariadb does not like this FK constraint
+                # ForeignKey("room.id"),
+                primary_key=True,
+            ),
+            Column(
+                "room_b_id",
+                Integer,
+                # mariadb does not like this FK constraint
+                # ForeignKey("room.id"),
+                primary_key=True,
+            ),
+        )
+
+        class Room(Base):
+            __tablename__ = "room"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            token: Mapped[str] = mapped_column(unique=True)
+            color_id: Mapped[int] = mapped_column(ForeignKey("color.id"))
+            color: Mapped[Color] = relationship()
+            connected_rooms: Mapped[List["Room"]] = relationship(  # noqa: F821
+                secondary=room_connections,
+                primaryjoin=id == room_connections.c.room_a_id,
+                secondaryjoin=id == room_connections.c.room_b_id,
+            )
+
+    @classmethod
+    def insert_data(cls, connection):
+        Room, Temperature, Color = cls.classes("Room", "Temperature", "Color")
+        with Session(connection) as session:
+            warm = Temperature(pointless_flag=True)
+            cool = Temperature(pointless_flag=True)
+            session.add_all([warm, cool])
+
+            red = Color(name="red", temperature=warm)
+            orange = Color(name="orange", temperature=warm)
+            blue = Color(name="blue", temperature=cool)
+            green = Color(name="green", temperature=cool)
+            session.add_all([red, orange, blue, green])
+
+            red1 = Room(token="Red-1", color=red)
+            red2 = Room(token="Red-2", color=red)
+            orange2 = Room(token="Orange-2", color=orange)
+            blue1 = Room(token="Blue-1", color=blue)
+            blue2 = Room(token="Blue-2", color=blue)
+            green1 = Room(token="Green-1", color=green)
+            red1.connected_rooms = [red2, blue1, green1]
+            red2.connected_rooms = [red1, blue2, orange2]
+            blue1.connected_rooms = [red1, blue2, green1]
+            blue2.connected_rooms = [red2, blue1, orange2]
+            session.add_all([red1, red2, blue1, blue2, green1, orange2])
+
+            session.commit()
+
+    @testing.variation(
+        "join_on_relationship", ["alone", "with_and", "no", "omit"]
+    )
+    def test_selectinload(self, join_on_relationship):
+        Room, Temperature, Color = self.classes("Room", "Temperature", "Color")
+        similar_color = aliased(Color)
+        subquery = (
+            select(Color.id)
+            .join(
+                similar_color,
+                similar_color.temperature_id == Color.temperature_id,
+            )
+            .where(similar_color.name == "red")
+        )
+
+        if join_on_relationship.alone:
+            subquery = subquery.join(Color.temperature).where(
+                Temperature.pointless_flag == True
+            )
+        elif join_on_relationship.with_and:
+            subquery = subquery.join(
+                Color.temperature.and_(Temperature.pointless_flag == True)
+            )
+        elif join_on_relationship.no:
+            subquery = subquery.join(
+                Temperature, Color.temperature_id == Temperature.id
+            ).where(Temperature.pointless_flag == True)
+        elif join_on_relationship.omit:
+            pass
+        else:
+            join_on_relationship.fail()
+
+        session = fixture_session()
+        room_result = session.scalars(
+            select(Room)
+            .order_by(Room.id)
+            .join(Room.color.and_(Color.name == "red"))
+            .options(
+                selectinload(
+                    Room.connected_rooms.and_(Room.color_id.in_(subquery))
+                )
+            )
+        ).unique()
+
+        self._assert_result(room_result)
+
+    def test_contains_eager(self):
+        Room, Temperature, Color = self.classes("Room", "Temperature", "Color")
+        similar_color = aliased(Color)
+        subquery = (
+            select(Color.id)
+            .join(
+                similar_color,
+                similar_color.temperature_id == Color.temperature_id,
+            )
+            .join(Color.temperature.and_(Temperature.pointless_flag == True))
+            .where(similar_color.name == "red")
+        )
+
+        room_alias = aliased(Room)
+        session = fixture_session()
+
+        room_result = session.scalars(
+            select(Room)
+            .order_by(Room.id)
+            .join(Room.color.and_(Color.name == "red"))
+            .join(
+                room_alias,
+                Room.connected_rooms.of_type(room_alias).and_(
+                    room_alias.color_id.in_(subquery)
+                ),
+            )
+            .options(contains_eager(Room.connected_rooms.of_type(room_alias)))
+        ).unique()
+
+        self._assert_result(room_result)
+
+    def _assert_result(self, room_result):
+        eq_(
+            [
+                (
+                    each_room.token,
+                    [room.token for room in each_room.connected_rooms],
+                )
+                for each_room in room_result
+            ],
+            [
+                ("Red-1", ["Red-2"]),
+                ("Red-2", ["Red-1", "Orange-2"]),
+            ],
+        )