From: Mike Bayer Date: Wed, 22 Dec 2021 20:33:11 +0000 (-0500) Subject: add recursion check for with_loader_criteria() option X-Git-Tag: rel_2_0_0b1~585^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c66c6d1aeff92f838740b7745a9c2a47852949d6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add recursion check for with_loader_criteria() option Fixed recursion overflow which could occur within ORM statement compilation when using either the :func:`_orm.with_loader_criteria` feature or the the :meth:`_orm.PropComparator.and_` method within a loader strategy in conjunction with a subquery which referred to the same entity being altered by the criteria option, or loaded by the loader strategy. A check for coming across the same loader criteria option in a recursive fashion has been added to accommodate for this scenario. Fixes: #7491 Change-Id: I8701332717c45a21948ea4788a3058c0fbbf03a7 --- diff --git a/doc/build/changelog/unreleased_14/7491.rst b/doc/build/changelog/unreleased_14/7491.rst new file mode 100644 index 0000000000..f1a19525bb --- /dev/null +++ b/doc/build/changelog/unreleased_14/7491.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, orm + :tickets: 7491 + + Fixed recursion overflow which could occur within ORM statement compilation + when using either the :func:`_orm.with_loader_criteria` feature or the the + :meth:`_orm.PropComparator.and_` method within a loader strategy in + conjunction with a subquery which referred to the same entity being altered + by the criteria option, or loaded by the loader strategy. A check for + coming across the same loader criteria option in a recursive fashion has + been added to accommodate for this scenario. + diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 61b9572805..4e2586203c 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -2137,7 +2137,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState): for ae in self.global_attributes[ ("additional_entity_criteria", ext_info.mapper) ] - if ae.include_aliases or ae.entity is ext_info + if (ae.include_aliases or ae.entity is ext_info) + and ae._should_include(self) ) else: return () diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 140464b87c..fef65f73c2 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1149,11 +1149,24 @@ class LoaderCriteriaOption(CriteriaOption): else: stack.extend(subclass.__subclasses__()) + def _should_include(self, compile_state): + if ( + compile_state.select_statement._annotations.get( + "for_loader_criteria", None + ) + is self + ): + return False + return True + def _resolve_where_criteria(self, ext_info): if self.deferred_where_criteria: - return self.where_criteria._resolve_with_args(ext_info.entity) + crit = self.where_criteria._resolve_with_args(ext_info.entity) else: - return self.where_criteria + crit = self.where_criteria + return sql_util._deep_annotate( + crit, {"for_loader_criteria": self}, detect_subquery_cols=True + ) def process_compile_state_replaced_entities( self, compile_state, mapper_entities diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 519a3103bb..1706da44e0 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -243,7 +243,9 @@ class Annotated: annotated_classes = {} -def _deep_annotate(element, annotations, exclude=None): +def _deep_annotate( + element, annotations, exclude=None, detect_subquery_cols=False +): """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary. @@ -257,6 +259,7 @@ def _deep_annotate(element, annotations, exclude=None): cloned_ids = {} def clone(elem, **kw): + kw["detect_subquery_cols"] = detect_subquery_cols id_ = id(elem) if id_ in cloned_ids: @@ -267,9 +270,12 @@ def _deep_annotate(element, annotations, exclude=None): and hasattr(elem, "proxy_set") and elem.proxy_set.intersection(exclude) ): - newelem = elem._clone(**kw) + newelem = elem._clone(clone=clone, **kw) elif annotations != elem._annotations: - newelem = elem._annotate(annotations) + if detect_subquery_cols and elem._is_immutable: + newelem = elem._clone(clone=clone, **kw)._annotate(annotations) + else: + newelem = elem._annotate(annotations) else: newelem = elem newelem._copy_internals(clone=clone) diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py index b5c9be0382..d8d3844cd1 100644 --- a/test/orm/test_relationship_criteria.py +++ b/test/orm/test_relationship_criteria.py @@ -4,6 +4,7 @@ import random from sqlalchemy import Column from sqlalchemy import DateTime from sqlalchemy import event +from sqlalchemy import exc as sa_exc from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer @@ -25,6 +26,7 @@ from sqlalchemy.orm import subqueryload from sqlalchemy.orm import with_loader_criteria from sqlalchemy.orm.decl_api import declared_attr from sqlalchemy.testing import eq_ +from sqlalchemy.testing.assertions import expect_raises from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.util import resolve_lambda @@ -250,6 +252,50 @@ class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): "WHERE users.name != :name_1", ) + def test_with_loader_criteria_recursion_check_scalar_subq( + self, user_address_fixture + ): + """test #7491""" + + User, Address = user_address_fixture + subq = select(Address).where(Address.id == 8).scalar_subquery() + stmt = ( + select(User) + .join(Address) + .options(with_loader_criteria(Address, Address.id == subq)) + ) + self.assert_compile( + stmt, + "SELECT users.id, users.name FROM users JOIN addresses " + "ON users.id = addresses.user_id AND addresses.id = " + "(SELECT addresses.id, addresses.user_id, " + "addresses.email_address FROM addresses " + "WHERE addresses.id = :id_1)", + ) + + def test_with_loader_criteria_recursion_check_from_subq( + self, user_address_fixture + ): + """test #7491""" + + User, Address = user_address_fixture + subq = select(Address).where(Address.id == 8).subquery() + stmt = ( + select(User) + .join(Address) + .options(with_loader_criteria(Address, Address.id == subq.c.id)) + ) + # note this query is incorrect SQL right now. This is a current + # artifact of how with_loader_criteria() is used and may be considered + # a bug at some point, in which case if fixed this query can be + # changed. the main thing we are testing at the moment is that + # there is not a recursion overflow. + self.assert_compile( + stmt, + "SELECT users.id, users.name FROM users JOIN addresses " + "ON users.id = addresses.user_id AND addresses.id = anon_1.id", + ) + def test_select_mapper_columns_mapper_criteria(self, user_address_fixture): User, Address = user_address_fixture @@ -1300,6 +1346,78 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): ), ) + @testing.combinations( + (joinedload, False), + (lazyload, True), + (subqueryload, False), + (selectinload, True), + argnames="opt,results_supported", + ) + def test_loader_criteria_subquery_w_same_entity( + self, user_address_fixture, opt, results_supported + ): + """test #7491. + + note this test also uses the not-quite-supported form of subquery + criteria introduced by #7489. where we also have to clone + the subquery linked only from a column criteria. this required + additional changes to the _annotate() method that is also + test here, which is why two of the loader strategies still fail; + we're just testing that there's no recursion overflow with this + very particular form. + + """ + User, Address = user_address_fixture + + s = Session(testing.db, future=True) + + def go(value): + subq = ( + select(Address.id) + .where(Address.email_address != value) + .subquery() + ) + stmt = ( + select(User) + .options( + # subquery here would need to be added to the FROM + # clause. this isn't quite supported and won't work + # right now with joinedoad() or subqueryload(). + opt(User.addresses.and_(Address.id == subq.c.id)), + ) + .order_by(User.id) + ) + result = s.execute(stmt) + return result + + for value in ( + "ed@wood.com", + "ed@lala.com", + "ed@wood.com", + "ed@lala.com", + ): + s.close() + + if not results_supported: + # for joinedload and subqueryload, the query generated here + # is invalid right now; this is because it's already not + # quite a supported pattern to refer to a subquery-bound + # column in loader criteria. However, the main thing we want + # to prevent here is the recursion overflow, so make sure + # we get a DBAPI error at least indicating compilation + # succeeded. + with expect_raises(sa_exc.DBAPIError): + go(value).scalars().unique().all() + else: + result = go(value).scalars().unique().all() + + eq_( + result, + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture), + ) + @testing.combinations((True,), (False,), argnames="use_compiled_cache") def test_selectinload_nested_criteria( self, user_order_item_fixture, use_compiled_cache