]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
safe annotate QueryableAttribute inside of join() condition
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 10 Aug 2023 22:26:45 +0000 (18:26 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 11 Aug 2023 16:36:21 +0000 (12:36 -0400)
Fixed fundamental issue which prevented some forms of ORM "annotations"
from taking place for subqueries which made use of :meth:`_sql.Select.join`
against a relationship target.  These annotations are used whenever a
subquery is used in special situations such as within
:meth:`_orm.PropComparator.and_` and other ORM-specific scenarios.

Fixes: #10223
Change-Id: I40f04265a6caa0fdcbc9f1b121a35561ab4b1fcf
(cherry picked from commit 6cfdc0743b7d1ebee3582f612a4f8acaa6ab42f9)

doc/build/changelog/unreleased_14/10223.rst [new file with mode: 0644]
lib/sqlalchemy/sql/annotation.py
test/orm/test_rel_fn.py
test/orm/test_relationship_criteria.py

diff --git a/doc/build/changelog/unreleased_14/10223.rst b/doc/build/changelog/unreleased_14/10223.rst
new file mode 100644 (file)
index 0000000..7c74424
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 10223
+    :versions: 2.0.20
+
+    Fixed fundamental issue which prevented some forms of ORM "annotations"
+    from taking place for subqueries which made use of :meth:`_sql.Select.join`
+    against a relationship target.  These annotations are used whenever a
+    subquery is used in special situations such as within
+    :meth:`_orm.PropComparator.and_` and other ORM-specific scenarios.
index 60e600ddf0bfc9338d6c39b5836c76431bc28c2d..f98038d6a21d2441ee0f23da19f6aa5469e4b4ba 100644 (file)
@@ -242,6 +242,18 @@ class Annotated(object):
 annotated_classes = {}
 
 
+def _safe_annotate(to_annotate, annotations):
+    try:
+        _annotate = to_annotate._annotate
+    except AttributeError:
+        # skip objects that don't actually have an `_annotate`
+        # attribute, namely QueryableAttribute inside of a join
+        # condition
+        return to_annotate
+    else:
+        return _annotate(annotations)
+
+
 def _deep_annotate(
     element, annotations, exclude=None, detect_subquery_cols=False
 ):
@@ -272,9 +284,11 @@ def _deep_annotate(
             newelem = elem._clone(clone=clone, **kw)
         elif annotations != elem._annotations:
             if detect_subquery_cols and elem._is_immutable:
-                newelem = elem._clone(clone=clone, **kw)._annotate(annotations)
+                newelem = _safe_annotate(
+                    elem._clone(clone=clone, **kw), annotations
+                )
             else:
-                newelem = elem._annotate(annotations)
+                newelem = _safe_annotate(elem, annotations)
         else:
             newelem = elem
         newelem._copy_internals(clone=clone)
index 4d8eb88b91cd0fc7a735076848cecf259c3f9095..a4e769d445d8b46851f9564e5ac7a6f85ad75031 100644 (file)
@@ -1243,6 +1243,28 @@ class LazyClauseTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
 
 
 class DeannotateCorrectlyTest(fixtures.TestBase):
+    def test_annotate_orm_join(self):
+        """test for #10223"""
+        from sqlalchemy.orm import declarative_base
+
+        Base = declarative_base()
+
+        class A(Base):
+            __tablename__ = "a"
+            id = Column(Integer, primary_key=True)
+            bs = relationship("B")
+
+        class B(Base):
+            __tablename__ = "b"
+            id = Column(Integer, primary_key=True)
+            a_id = Column(ForeignKey(A.id))
+
+        stmt = select(A).join(A.bs)
+
+        from sqlalchemy.sql import util
+
+        util._deep_annotate(stmt, {"foo": "bar"})
+
     def test_pj_deannotates(self):
         from sqlalchemy.orm import declarative_base
 
index e866fe018626e64c979d7c71caef7d34c4641977..e1dc0ae29e98b8feaee365ba0a92fece88287e43 100644 (file)
@@ -1,6 +1,7 @@
 import datetime
 import random
 
+from sqlalchemy import Boolean
 from sqlalchemy import Column
 from sqlalchemy import DateTime
 from sqlalchemy import event
@@ -13,9 +14,11 @@ from sqlalchemy import orm
 from sqlalchemy import select
 from sqlalchemy import sql
 from sqlalchemy import String
+from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import column_property
+from sqlalchemy.orm import contains_eager
 from sqlalchemy.orm import defer
 from sqlalchemy.orm import join as orm_join
 from sqlalchemy.orm import joinedload
@@ -29,6 +32,7 @@ from sqlalchemy.orm import with_loader_criteria
 from sqlalchemy.orm.decl_api import declared_attr
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import expect_raises
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.fixtures import fixture_session
@@ -2137,3 +2141,175 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
             "JOIN items AS items_1 ON items_1.id = order_items_1.item_id "
             "AND items_1.description != :description_1",
         )
+
+
+class SubqueryCriteriaTest(fixtures.DeclarativeMappedTest):
+    """test #10223"""
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Temperature(Base):
+            __tablename__ = "temperature"
+            id = Column(Integer, primary_key=True)
+            pointless_flag = Column(Boolean)
+
+        class Color(Base):
+            __tablename__ = "color"
+            id = Column(Integer, primary_key=True)
+            name = Column(String(10))
+            temperature_id = Column(ForeignKey("temperature.id"))
+            temperature = relationship("Temperature")
+
+        room_connections = Table(
+            "room_connections",
+            Base.metadata,
+            Column(
+                "room_a_id",
+                Integer,
+                # mariadb does not like this FK constraint
+                # ForeignKey("room.id"),
+                primary_key=True,
+            ),
+            Column(
+                "room_b_id",
+                Integer,
+                # mariadb does not like this FK constraint
+                # ForeignKey("room.id"),
+                primary_key=True,
+            ),
+        )
+
+        class Room(Base):
+            __tablename__ = "room"
+            id = Column(Integer, primary_key=True)
+            token = Column(String(10))
+            color_id = Column(ForeignKey("color.id"))
+            color = relationship("Color")
+            connected_rooms = relationship(
+                "Room",
+                secondary=room_connections,
+                primaryjoin=id == room_connections.c.room_a_id,
+                secondaryjoin=id == room_connections.c.room_b_id,
+            )
+
+    @classmethod
+    def insert_data(cls, connection):
+        Room, Temperature, Color = cls.classes("Room", "Temperature", "Color")
+        with Session(connection) as session:
+            warm = Temperature(pointless_flag=True)
+            cool = Temperature(pointless_flag=True)
+            session.add_all([warm, cool])
+
+            red = Color(name="red", temperature=warm)
+            orange = Color(name="orange", temperature=warm)
+            blue = Color(name="blue", temperature=cool)
+            green = Color(name="green", temperature=cool)
+            session.add_all([red, orange, blue, green])
+
+            red1 = Room(token="Red-1", color=red)
+            red2 = Room(token="Red-2", color=red)
+            orange2 = Room(token="Orange-2", color=orange)
+            blue1 = Room(token="Blue-1", color=blue)
+            blue2 = Room(token="Blue-2", color=blue)
+            green1 = Room(token="Green-1", color=green)
+            red1.connected_rooms = [red2, blue1, green1]
+            red2.connected_rooms = [red1, blue2, orange2]
+            blue1.connected_rooms = [red1, blue2, green1]
+            blue2.connected_rooms = [red2, blue1, orange2]
+            session.add_all([red1, red2, blue1, blue2, green1, orange2])
+
+            session.commit()
+
+    @testing.variation(
+        "join_on_relationship", ["alone", "with_and", "no", "omit"]
+    )
+    def test_selectinload(self, join_on_relationship):
+        Room, Temperature, Color = self.classes("Room", "Temperature", "Color")
+        similar_color = aliased(Color)
+        subquery = (
+            select(Color.id)
+            .join(
+                similar_color,
+                similar_color.temperature_id == Color.temperature_id,
+            )
+            .where(similar_color.name == "red")
+        )
+
+        if join_on_relationship.alone:
+            subquery = subquery.join(Color.temperature).where(
+                Temperature.pointless_flag == True
+            )
+        elif join_on_relationship.with_and:
+            subquery = subquery.join(
+                Color.temperature.and_(Temperature.pointless_flag == True)
+            )
+        elif join_on_relationship.no:
+            subquery = subquery.join(
+                Temperature, Color.temperature_id == Temperature.id
+            ).where(Temperature.pointless_flag == True)
+        elif join_on_relationship.omit:
+            pass
+        else:
+            join_on_relationship.fail()
+
+        session = fixture_session()
+        room_result = session.scalars(
+            select(Room)
+            .order_by(Room.id)
+            .join(Room.color.and_(Color.name == "red"))
+            .options(
+                selectinload(
+                    Room.connected_rooms.and_(Room.color_id.in_(subquery))
+                )
+            )
+        ).unique()
+
+        self._assert_result(room_result)
+
+    def test_contains_eager(self):
+        Room, Temperature, Color = self.classes("Room", "Temperature", "Color")
+        similar_color = aliased(Color)
+        subquery = (
+            select(Color.id)
+            .join(
+                similar_color,
+                similar_color.temperature_id == Color.temperature_id,
+            )
+            .join(Color.temperature.and_(Temperature.pointless_flag == True))
+            .where(similar_color.name == "red")
+        )
+
+        room_alias = aliased(Room)
+        session = fixture_session()
+
+        room_result = session.scalars(
+            select(Room)
+            .order_by(Room.id)
+            .join(Room.color.and_(Color.name == "red"))
+            .join(
+                room_alias,
+                Room.connected_rooms.of_type(room_alias).and_(
+                    room_alias.color_id.in_(subquery)
+                ),
+            )
+            .options(contains_eager(Room.connected_rooms.of_type(room_alias)))
+        ).unique()
+
+        self._assert_result(room_result)
+
+    def _assert_result(self, room_result):
+        eq_(
+            [
+                (
+                    each_room.token,
+                    [room.token for room in each_room.connected_rooms],
+                )
+                for each_room in room_result
+            ],
+            [
+                ("Red-1", ["Red-2"]),
+                ("Red-2", ["Red-1", "Orange-2"]),
+            ],
+        )