]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add recursion check for with_loader_criteria() option
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Dec 2021 20:33:11 +0000 (15:33 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Dec 2021 20:46:48 +0000 (15:46 -0500)
Fixed recursion overflow which could occur within ORM statement compilation
when using either the :func:`_orm.with_loader_criteria` feature or the
:meth:`_orm.PropComparator.and_` method within a loader strategy in
conjunction with a subquery which referred to the same entity being altered
by the criteria option, or loaded by the loader strategy.  A check for
coming across the same loader criteria option in a recursive fashion has
been added to accommodate this scenario.

Fixes: #7491
Change-Id: I8701332717c45a21948ea4788a3058c0fbbf03a7

doc/build/changelog/unreleased_14/7491.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/annotation.py
test/orm/test_relationship_criteria.py

diff --git a/doc/build/changelog/unreleased_14/7491.rst b/doc/build/changelog/unreleased_14/7491.rst
new file mode 100644 (file)
index 0000000..f1a1952
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 7491
+
+    Fixed recursion overflow which could occur within ORM statement compilation
+    when using either the :func:`_orm.with_loader_criteria` feature or the
+    :meth:`_orm.PropComparator.and_` method within a loader strategy in
+    conjunction with a subquery which referred to the same entity being altered
+    by the criteria option, or loaded by the loader strategy.  A check for
+    coming across the same loader criteria option in a recursive fashion has
+    been added to accommodate this scenario.
+
index 61b95728057b4b5b02cf1f94d6fe6efed063b371..4e2586203c2f9a834e7cdfa65b3cd9cd3e6db10e 100644 (file)
@@ -2137,7 +2137,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
                 for ae in self.global_attributes[
                     ("additional_entity_criteria", ext_info.mapper)
                 ]
-                if ae.include_aliases or ae.entity is ext_info
+                if (ae.include_aliases or ae.entity is ext_info)
+                and ae._should_include(self)
             )
         else:
             return ()
index 140464b87cbc7ebfd97ac259747ea990e9713fda..fef65f73c2e4f9159835ece49d8835433193e481 100644 (file)
@@ -1149,11 +1149,24 @@ class LoaderCriteriaOption(CriteriaOption):
                 else:
                     stack.extend(subclass.__subclasses__())
 
+    def _should_include(self, compile_state):
+        if (
+            compile_state.select_statement._annotations.get(
+                "for_loader_criteria", None
+            )
+            is self
+        ):
+            return False
+        return True
+
     def _resolve_where_criteria(self, ext_info):
         if self.deferred_where_criteria:
-            return self.where_criteria._resolve_with_args(ext_info.entity)
+            crit = self.where_criteria._resolve_with_args(ext_info.entity)
         else:
-            return self.where_criteria
+            crit = self.where_criteria
+        return sql_util._deep_annotate(
+            crit, {"for_loader_criteria": self}, detect_subquery_cols=True
+        )
 
     def process_compile_state_replaced_entities(
         self, compile_state, mapper_entities
index 519a3103bb713c30a4269c8532bdb01643224633..1706da44e0c6707de714e755f33e2a12c88fe34d 100644 (file)
@@ -243,7 +243,9 @@ class Annotated:
 annotated_classes = {}
 
 
-def _deep_annotate(element, annotations, exclude=None):
+def _deep_annotate(
+    element, annotations, exclude=None, detect_subquery_cols=False
+):
     """Deep copy the given ClauseElement, annotating each element
     with the given annotations dictionary.
 
@@ -257,6 +259,7 @@ def _deep_annotate(element, annotations, exclude=None):
     cloned_ids = {}
 
     def clone(elem, **kw):
+        kw["detect_subquery_cols"] = detect_subquery_cols
         id_ = id(elem)
 
         if id_ in cloned_ids:
@@ -267,9 +270,12 @@ def _deep_annotate(element, annotations, exclude=None):
             and hasattr(elem, "proxy_set")
             and elem.proxy_set.intersection(exclude)
         ):
-            newelem = elem._clone(**kw)
+            newelem = elem._clone(clone=clone, **kw)
         elif annotations != elem._annotations:
-            newelem = elem._annotate(annotations)
+            if detect_subquery_cols and elem._is_immutable:
+                newelem = elem._clone(clone=clone, **kw)._annotate(annotations)
+            else:
+                newelem = elem._annotate(annotations)
         else:
             newelem = elem
         newelem._copy_internals(clone=clone)
index b5c9be03829aa75d6b6a0d7508911de63b03818c..d8d3844cd1136e8c2f576389154ff6d86a01ef33 100644 (file)
@@ -4,6 +4,7 @@ import random
 from sqlalchemy import Column
 from sqlalchemy import DateTime
 from sqlalchemy import event
+from sqlalchemy import exc as sa_exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
@@ -25,6 +26,7 @@ from sqlalchemy.orm import subqueryload
 from sqlalchemy.orm import with_loader_criteria
 from sqlalchemy.orm.decl_api import declared_attr
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing.assertions import expect_raises
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.util import resolve_lambda
@@ -250,6 +252,50 @@ class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
             "WHERE users.name != :name_1",
         )
 
+    def test_with_loader_criteria_recursion_check_scalar_subq(
+        self, user_address_fixture
+    ):
+        """test #7491"""
+
+        User, Address = user_address_fixture
+        subq = select(Address).where(Address.id == 8).scalar_subquery()
+        stmt = (
+            select(User)
+            .join(Address)
+            .options(with_loader_criteria(Address, Address.id == subq))
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT users.id, users.name FROM users JOIN addresses "
+            "ON users.id = addresses.user_id AND addresses.id = "
+            "(SELECT addresses.id, addresses.user_id, "
+            "addresses.email_address FROM addresses "
+            "WHERE addresses.id = :id_1)",
+        )
+
+    def test_with_loader_criteria_recursion_check_from_subq(
+        self, user_address_fixture
+    ):
+        """test #7491"""
+
+        User, Address = user_address_fixture
+        subq = select(Address).where(Address.id == 8).subquery()
+        stmt = (
+            select(User)
+            .join(Address)
+            .options(with_loader_criteria(Address, Address.id == subq.c.id))
+        )
+        # note this query is incorrect SQL right now.   This is a current
+        # artifact of how with_loader_criteria() is used and may be considered
+        # a bug at some point, in which case if fixed this query can be
+        # changed.  the main thing we are testing at the moment is that
+        # there is not a recursion overflow.
+        self.assert_compile(
+            stmt,
+            "SELECT users.id, users.name FROM users JOIN addresses "
+            "ON users.id = addresses.user_id AND addresses.id = anon_1.id",
+        )
+
     def test_select_mapper_columns_mapper_criteria(self, user_address_fixture):
         User, Address = user_address_fixture
 
@@ -1300,6 +1346,78 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
                 ),
             )
 
+    @testing.combinations(
+        (joinedload, False),
+        (lazyload, True),
+        (subqueryload, False),
+        (selectinload, True),
+        argnames="opt,results_supported",
+    )
+    def test_loader_criteria_subquery_w_same_entity(
+        self, user_address_fixture, opt, results_supported
+    ):
+        """test #7491.
+
+        note this test also uses the not-quite-supported form of subquery
+        criteria introduced by #7489. where we also have to clone
+        the subquery linked only from a column criteria.  this required
+        additional changes to the _annotate() method that are also
+        tested here, which is why two of the loader strategies still fail;
+        we're just testing that there's no recursion overflow with this
+        very particular form.
+
+        """
+        User, Address = user_address_fixture
+
+        s = Session(testing.db, future=True)
+
+        def go(value):
+            subq = (
+                select(Address.id)
+                .where(Address.email_address != value)
+                .subquery()
+            )
+            stmt = (
+                select(User)
+                .options(
+                    # subquery here would need to be added to the FROM
+                    # clause.  this isn't quite supported and won't work
+                    # right now with joinedload() or subqueryload().
+                    opt(User.addresses.and_(Address.id == subq.c.id)),
+                )
+                .order_by(User.id)
+            )
+            result = s.execute(stmt)
+            return result
+
+        for value in (
+            "ed@wood.com",
+            "ed@lala.com",
+            "ed@wood.com",
+            "ed@lala.com",
+        ):
+            s.close()
+
+            if not results_supported:
+                # for joinedload and subqueryload, the query generated here
+                # is invalid right now; this is because it's already not
+                # quite a supported pattern to refer to a subquery-bound
+                # column in loader criteria.  However, the main thing we want
+                # to prevent here is the recursion overflow, so make sure
+                # we get a DBAPI error at least indicating compilation
+                # succeeded.
+                with expect_raises(sa_exc.DBAPIError):
+                    go(value).scalars().unique().all()
+            else:
+                result = go(value).scalars().unique().all()
+
+                eq_(
+                    result,
+                    self._user_minus_edwood(*user_address_fixture)
+                    if value == "ed@wood.com"
+                    else self._user_minus_edlala(*user_address_fixture),
+                )
+
     @testing.combinations((True,), (False,), argnames="use_compiled_cache")
     def test_selectinload_nested_criteria(
         self, user_order_item_fixture, use_compiled_cache