From: Mike Bayer Date: Thu, 10 Aug 2023 22:26:45 +0000 (-0400) Subject: safe annotate QueryableAttribute inside of join() condition X-Git-Tag: rel_1_4_50~11 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=9e8b910c9a2b52de471c662caabb65e62cabf3c6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git safe annotate QueryableAttribute inside of join() condition Fixed fundamental issue which prevented some forms of ORM "annotations" from taking place for subqueries which made use of :meth:`_sql.Select.join` against a relationship target. These annotations are used whenever a subquery is used in special situations such as within :meth:`_orm.PropComparator.and_` and other ORM-specific scenarios. Fixes: #10223 Change-Id: I40f04265a6caa0fdcbc9f1b121a35561ab4b1fcf (cherry picked from commit 6cfdc0743b7d1ebee3582f612a4f8acaa6ab42f9) --- diff --git a/doc/build/changelog/unreleased_14/10223.rst b/doc/build/changelog/unreleased_14/10223.rst new file mode 100644 index 0000000000..7c74424060 --- /dev/null +++ b/doc/build/changelog/unreleased_14/10223.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, orm + :tickets: 10223 + :versions: 2.0.20 + + Fixed fundamental issue which prevented some forms of ORM "annotations" + from taking place for subqueries which made use of :meth:`_sql.Select.join` + against a relationship target. These annotations are used whenever a + subquery is used in special situations such as within + :meth:`_orm.PropComparator.and_` and other ORM-specific scenarios. diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 60e600ddf0..f98038d6a2 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -242,6 +242,18 @@ class Annotated(object): annotated_classes = {} +def _safe_annotate(to_annotate, annotations): + try: + _annotate = to_annotate._annotate + except AttributeError: + # skip objects that don't actually have an `_annotate` + # attribute, namely QueryableAttribute inside of a join + # condition + return to_annotate + else: + return _annotate(annotations) + + def _deep_annotate( element, annotations, exclude=None, detect_subquery_cols=False ): @@ -272,9 +284,11 @@ def _deep_annotate( newelem = elem._clone(clone=clone, **kw) elif annotations != elem._annotations: if detect_subquery_cols and elem._is_immutable: - newelem = elem._clone(clone=clone, **kw)._annotate(annotations) + newelem = _safe_annotate( + elem._clone(clone=clone, **kw), annotations + ) else: - newelem = elem._annotate(annotations) + newelem = _safe_annotate(elem, annotations) else: newelem = elem newelem._copy_internals(clone=clone) diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py index 4d8eb88b91..a4e769d445 100644 --- a/test/orm/test_rel_fn.py +++ b/test/orm/test_rel_fn.py @@ -1243,6 +1243,28 @@ class LazyClauseTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): class DeannotateCorrectlyTest(fixtures.TestBase): + def test_annotate_orm_join(self): + """test for #10223""" + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + class A(Base): + __tablename__ = "a" + id = Column(Integer, primary_key=True) + bs = relationship("B") + + class B(Base): + __tablename__ = "b" + id = Column(Integer, primary_key=True) + a_id = Column(ForeignKey(A.id)) + + stmt = select(A).join(A.bs) + + from sqlalchemy.sql import util + + util._deep_annotate(stmt, {"foo": "bar"}) + def test_pj_deannotates(self): from sqlalchemy.orm import declarative_base diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py index e866fe0186..e1dc0ae29e 100644 --- a/test/orm/test_relationship_criteria.py +++ b/test/orm/test_relationship_criteria.py @@ -1,6 +1,7 @@ import datetime import random +from sqlalchemy import Boolean from sqlalchemy import Column from sqlalchemy import DateTime from sqlalchemy import event @@ -13,9 +14,11 @@ from sqlalchemy import orm from sqlalchemy import select from sqlalchemy import sql from sqlalchemy import String +from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy.orm import aliased from sqlalchemy.orm import column_property +from sqlalchemy.orm import contains_eager from sqlalchemy.orm import defer from sqlalchemy.orm import join as orm_join from sqlalchemy.orm import joinedload @@ -29,6 +32,7 @@ from sqlalchemy.orm import with_loader_criteria from sqlalchemy.orm.decl_api import declared_attr from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import expect_raises from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.fixtures import fixture_session @@ -2137,3 +2141,175 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): "JOIN items AS items_1 ON items_1.id = order_items_1.item_id " "AND items_1.description != :description_1", ) + + +class SubqueryCriteriaTest(fixtures.DeclarativeMappedTest): + """test #10223""" + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class Temperature(Base): + __tablename__ = "temperature" + id = Column(Integer, primary_key=True) + pointless_flag = Column(Boolean) + + class Color(Base): + __tablename__ = "color" + id = Column(Integer, primary_key=True) + name = Column(String(10)) + temperature_id = Column(ForeignKey("temperature.id")) + temperature = relationship("Temperature") + + room_connections = Table( + "room_connections", + Base.metadata, + Column( + "room_a_id", + Integer, + # mariadb does not like this FK constraint + # ForeignKey("room.id"), + primary_key=True, + ), + Column( + "room_b_id", + Integer, + # mariadb does not like this FK constraint + # ForeignKey("room.id"), + primary_key=True, + ), + ) + + class Room(Base): + __tablename__ = "room" + id = Column(Integer, primary_key=True) + token = Column(String(10)) + color_id = Column(ForeignKey("color.id")) + color = relationship("Color") + connected_rooms = relationship( + "Room", + secondary=room_connections, + primaryjoin=id == room_connections.c.room_a_id, + secondaryjoin=id == room_connections.c.room_b_id, + ) + + @classmethod + def insert_data(cls, connection): + Room, Temperature, Color = cls.classes("Room", "Temperature", "Color") + with Session(connection) as session: + warm = Temperature(pointless_flag=True) + cool = Temperature(pointless_flag=True) + session.add_all([warm, cool]) + + red = Color(name="red", temperature=warm) + orange = Color(name="orange", temperature=warm) + blue = Color(name="blue", temperature=cool) + green = Color(name="green", temperature=cool) + session.add_all([red, orange, blue, green]) + + red1 = Room(token="Red-1", color=red) + red2 = Room(token="Red-2", color=red) + orange2 = Room(token="Orange-2", color=orange) + blue1 = Room(token="Blue-1", color=blue) + blue2 = Room(token="Blue-2", color=blue) + green1 = Room(token="Green-1", color=green) + red1.connected_rooms = [red2, blue1, green1] + red2.connected_rooms = [red1, blue2, orange2] + blue1.connected_rooms = [red1, blue2, green1] + blue2.connected_rooms = [red2, blue1, orange2] + session.add_all([red1, red2, blue1, blue2, green1, orange2]) + + session.commit() + + @testing.variation( + "join_on_relationship", ["alone", "with_and", "no", "omit"] + ) + def test_selectinload(self, join_on_relationship): + Room, Temperature, Color = self.classes("Room", "Temperature", "Color") + similar_color = aliased(Color) + subquery = ( + select(Color.id) + .join( + similar_color, + similar_color.temperature_id == Color.temperature_id, + ) + .where(similar_color.name == "red") + ) + + if join_on_relationship.alone: + subquery = subquery.join(Color.temperature).where( + Temperature.pointless_flag == True + ) + elif join_on_relationship.with_and: + subquery = subquery.join( + Color.temperature.and_(Temperature.pointless_flag == True) + ) + elif join_on_relationship.no: + subquery = subquery.join( + Temperature, Color.temperature_id == Temperature.id + ).where(Temperature.pointless_flag == True) + elif join_on_relationship.omit: + pass + else: + join_on_relationship.fail() + + session = fixture_session() + room_result = session.scalars( + select(Room) + .order_by(Room.id) + .join(Room.color.and_(Color.name == "red")) + .options( + selectinload( + Room.connected_rooms.and_(Room.color_id.in_(subquery)) + ) + ) + ).unique() + + self._assert_result(room_result) + + def test_contains_eager(self): + Room, Temperature, Color = self.classes("Room", "Temperature", "Color") + similar_color = aliased(Color) + subquery = ( + select(Color.id) + .join( + similar_color, + similar_color.temperature_id == Color.temperature_id, + ) + .join(Color.temperature.and_(Temperature.pointless_flag == True)) + .where(similar_color.name == "red") + ) + + room_alias = aliased(Room) + session = fixture_session() + + room_result = session.scalars( + select(Room) + .order_by(Room.id) + .join(Room.color.and_(Color.name == "red")) + .join( + room_alias, + Room.connected_rooms.of_type(room_alias).and_( + room_alias.color_id.in_(subquery) + ), + ) + .options(contains_eager(Room.connected_rooms.of_type(room_alias))) + ).unique() + + self._assert_result(room_result) + + def _assert_result(self, room_result): + eq_( + [ + ( + each_room.token, + [room.token for room in each_room.connected_rooms], + ) + for each_room in room_result + ], + [ + ("Red-1", ["Red-2"]), + ("Red-2", ["Red-1", "Orange-2"]), + ], + )