From: Mike Bayer Date: Thu, 10 Aug 2023 22:26:45 +0000 (-0400) Subject: safe annotate QueryableAttribute inside of join() condition X-Git-Tag: rel_2_0_20~9 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f09036680c723b16f250f95267800cddaf1a9a42;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git safe annotate QueryableAttribute inside of join() condition Fixed fundamental issue which prevented some forms of ORM "annotations" from taking place for subqueries which made use of :meth:`_sql.Select.join` against a relationship target. These annotations are used whenever a subquery is used in special situations such as within :meth:`_orm.PropComparator.and_` and other ORM-specific scenarios. Fixes: #10223 Change-Id: I40f04265a6caa0fdcbc9f1b121a35561ab4b1fcf --- diff --git a/doc/build/changelog/unreleased_14/10223.rst b/doc/build/changelog/unreleased_14/10223.rst new file mode 100644 index 0000000000..7c74424060 --- /dev/null +++ b/doc/build/changelog/unreleased_14/10223.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, orm + :tickets: 10223 + :versions: 2.0.20 + + Fixed fundamental issue which prevented some forms of ORM "annotations" + from taking place for subqueries which made use of :meth:`_sql.Select.join` + against a relationship target. These annotations are used whenever a + subquery is used in special situations such as within + :meth:`_orm.PropComparator.and_` and other ORM-specific scenarios. diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index e2ba698926..d3a8da042a 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -78,6 +78,7 @@ from ..sql import roles from ..sql import visitors from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _HasClauseElement +from ..sql.annotation import _safe_annotate from ..sql.elements import ColumnClause from ..sql.elements import ColumnElement from ..sql.util import _deep_annotate @@ -3297,7 +3298,7 @@ class JoinCondition: parentmapper_for_element is not self.prop.parent and parentmapper_for_element is not self.prop.mapper ): - return elem._annotate(annotations) + return _safe_annotate(elem, annotations) else: return elem diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 016608a381..4ccde591a9 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -402,6 +402,18 @@ annotated_classes: Dict[ _SA = TypeVar("_SA", bound="SupportsAnnotations") +def _safe_annotate(to_annotate: _SA, annotations: _AnnotationDict) -> _SA: + try: + _annotate = to_annotate._annotate + except AttributeError: + # skip objects that don't actually have an `_annotate` + # attribute, namely QueryableAttribute inside of a join + # condition + return to_annotate + else: + return _annotate(annotations) + + def _deep_annotate( element: _SA, annotations: _AnnotationDict, @@ -455,7 +467,7 @@ def _deep_annotate( if annotate_callable: newelem = annotate_callable(to_annotate, annotations) else: - newelem = to_annotate._annotate(annotations) + newelem = _safe_annotate(to_annotate, annotations) else: newelem = elem diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py index 5b2a15c13a..83ffff3c91 100644 --- a/test/orm/test_rel_fn.py +++ b/test/orm/test_rel_fn.py @@ -1309,6 +1309,28 @@ class LazyClauseTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): class DeannotateCorrectlyTest(fixtures.TestBase): + def test_annotate_orm_join(self): + """test for #10223""" + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + class A(Base): + __tablename__ = "a" + id = Column(Integer, primary_key=True) + bs = relationship("B") + + class B(Base): + __tablename__ = "b" + id = Column(Integer, primary_key=True) + a_id = Column(ForeignKey(A.id)) + + stmt = select(A).join(A.bs) + + from sqlalchemy.sql import util + + util._deep_annotate(stmt, {"foo": "bar"}) + def test_pj_deannotates(self): from sqlalchemy.orm import declarative_base diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py index d03b79e892..31423e5f4a 100644 --- a/test/orm/test_relationship_criteria.py +++ b/test/orm/test_relationship_criteria.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import datetime import random +from typing import List from sqlalchemy import Column from sqlalchemy import DateTime @@ -15,15 +18,19 @@ from sqlalchemy import orm from sqlalchemy import select from sqlalchemy import sql from sqlalchemy import String +from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import union from sqlalchemy import update from sqlalchemy.orm import aliased from sqlalchemy.orm import column_property +from sqlalchemy.orm import contains_eager from sqlalchemy.orm import defer from sqlalchemy.orm import join as orm_join from sqlalchemy.orm import joinedload from sqlalchemy.orm import lazyload +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import registry from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload @@ -33,6 +40,7 @@ from sqlalchemy.orm import with_loader_criteria from sqlalchemy.orm.decl_api import declared_attr from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import expect_raises from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.fixtures import fixture_session @@ -2365,3 +2373,176 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): "JOIN items AS items_1 ON items_1.id = order_items_1.item_id " "AND items_1.description != :description_1", ) + + +class SubqueryCriteriaTest(fixtures.DeclarativeMappedTest): + """test #10223""" + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class Temperature(Base): + __tablename__ = "temperature" + id: Mapped[int] = mapped_column(primary_key=True) + pointless_flag: Mapped[bool] + + class Color(Base): + __tablename__ = "color" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(unique=True) + temperature_id: Mapped[int] = mapped_column( + ForeignKey("temperature.id") + ) + temperature: Mapped[Temperature] = relationship() + + room_connections = Table( + "room_connections", + Base.metadata, + Column( + "room_a_id", + Integer, + # mariadb does not like this FK constraint + # ForeignKey("room.id"), + primary_key=True, + ), + Column( + "room_b_id", + Integer, + # mariadb does not like this FK constraint + # ForeignKey("room.id"), + primary_key=True, + ), + ) + + class Room(Base): + __tablename__ = "room" + id: Mapped[int] = mapped_column(primary_key=True) + token: Mapped[str] = mapped_column(unique=True) + color_id: Mapped[int] = mapped_column(ForeignKey("color.id")) + color: Mapped[Color] = relationship() + connected_rooms: Mapped[List["Room"]] = relationship( # noqa: F821 + secondary=room_connections, + primaryjoin=id == room_connections.c.room_a_id, + secondaryjoin=id == room_connections.c.room_b_id, + ) + + @classmethod + def insert_data(cls, connection): + Room, Temperature, Color = cls.classes("Room", "Temperature", "Color") + with Session(connection) as session: + warm = Temperature(pointless_flag=True) + cool = Temperature(pointless_flag=True) + session.add_all([warm, cool]) + + red = Color(name="red", temperature=warm) + orange = Color(name="orange", temperature=warm) + blue = Color(name="blue", temperature=cool) + green = Color(name="green", temperature=cool) + session.add_all([red, orange, blue, green]) + + red1 = Room(token="Red-1", color=red) + red2 = Room(token="Red-2", color=red) + orange2 = Room(token="Orange-2", color=orange) + blue1 = Room(token="Blue-1", color=blue) + blue2 = Room(token="Blue-2", color=blue) + green1 = Room(token="Green-1", color=green) + red1.connected_rooms = [red2, blue1, green1] + red2.connected_rooms = [red1, blue2, orange2] + blue1.connected_rooms = [red1, blue2, green1] + blue2.connected_rooms = [red2, blue1, orange2] + session.add_all([red1, red2, blue1, blue2, green1, orange2]) + + session.commit() + + @testing.variation( + "join_on_relationship", ["alone", "with_and", "no", "omit"] + ) + def test_selectinload(self, join_on_relationship): + Room, Temperature, Color = self.classes("Room", "Temperature", "Color") + similar_color = aliased(Color) + subquery = ( + select(Color.id) + .join( + similar_color, + similar_color.temperature_id == Color.temperature_id, + ) + .where(similar_color.name == "red") + ) + + if join_on_relationship.alone: + subquery = subquery.join(Color.temperature).where( + Temperature.pointless_flag == True + ) + elif join_on_relationship.with_and: + subquery = subquery.join( + Color.temperature.and_(Temperature.pointless_flag == True) + ) + elif join_on_relationship.no: + subquery = subquery.join( + Temperature, Color.temperature_id == Temperature.id + ).where(Temperature.pointless_flag == True) + elif join_on_relationship.omit: + pass + else: + join_on_relationship.fail() + + session = fixture_session() + room_result = session.scalars( + select(Room) + .order_by(Room.id) + .join(Room.color.and_(Color.name == "red")) + .options( + selectinload( + Room.connected_rooms.and_(Room.color_id.in_(subquery)) + ) + ) + ).unique() + + self._assert_result(room_result) + + def test_contains_eager(self): + Room, Temperature, Color = self.classes("Room", "Temperature", "Color") + similar_color = aliased(Color) + subquery = ( + select(Color.id) + .join( + similar_color, + similar_color.temperature_id == Color.temperature_id, + ) + .join(Color.temperature.and_(Temperature.pointless_flag == True)) + .where(similar_color.name == "red") + ) + + room_alias = aliased(Room) + session = fixture_session() + + room_result = session.scalars( + select(Room) + .order_by(Room.id) + .join(Room.color.and_(Color.name == "red")) + .join( + room_alias, + Room.connected_rooms.of_type(room_alias).and_( + room_alias.color_id.in_(subquery) + ), + ) + .options(contains_eager(Room.connected_rooms.of_type(room_alias))) + ).unique() + + self._assert_result(room_result) + + def _assert_result(self, room_result): + eq_( + [ + ( + each_room.token, + [room.token for room in each_room.connected_rooms], + ) + for each_room in room_result + ], + [ + ("Red-1", ["Red-2"]), + ("Red-2", ["Red-1", "Orange-2"]), + ], + )