]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
propagate regular execution_options to secondary eager loaders
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 13 Aug 2023 15:16:03 +0000 (11:16 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Aug 2023 19:19:29 +0000 (15:19 -0400)
Fixed fairly major issue where execution options passed to
:meth:`_orm.Session.execute`, as well as execution options local to the ORM
executed statement itself, would not be propagated along to eager loaders
such as that of :func:`_orm.selectinload`, :func:`_orm.immediateload`, and
:meth:`_orm.subqueryload`, making it impossible to do things such as
disabling the cache for a single statement or using
``schema_translate_map`` for a single statement, as well as the use of
user-custom execution options.   A change has been made where **all**
user-facing execution options present for :meth:`_orm.Session.execute` will
be propagated along to additional loaders.

As part of this change, the warning for "excessively deep" eager loaders
leading to caching being disabled can be silenced on a per-statement
basis by sending ``execution_options={"compiled_cache": None}`` to
:meth:`_orm.Session.execute`, which will disable caching for the full
series of statements within that scope.

Fixes: #10231
Change-Id: I5304d3af0b78e1b4593c3558f117b7ac10b499ae

doc/build/changelog/unreleased_20/10231.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/strategies.py
test/orm/dml/test_update_delete_where.py
test/orm/test_recursive_loaders.py
test/orm/test_session.py

diff --git a/doc/build/changelog/unreleased_20/10231.rst b/doc/build/changelog/unreleased_20/10231.rst
new file mode 100644 (file)
index 0000000..35e8c4f
--- /dev/null
@@ -0,0 +1,20 @@
+.. change::
+    :tags: orm, bug
+    :tickets: 10231
+
+    Fixed fairly major issue where execution options passed to
+    :meth:`_orm.Session.execute`, as well as execution options local to the ORM
+    executed statement itself, would not be propagated along to eager loaders
+    such as that of :func:`_orm.selectinload`, :func:`_orm.immediateload`, and
+    :meth:`_orm.subqueryload`, making it impossible to do things such as
+    disabling the cache for a single statement or using
+    ``schema_translate_map`` for a single statement, as well as the use of
+    user-custom execution options.   A change has been made where **all**
+    user-facing execution options present for :meth:`_orm.Session.execute` will
+    be propagated along to additional loaders.
+
+    As part of this change, the warning for "excessively deep" eager loaders
+    leading to caching being disabled can be silenced on a per-statement
+    basis by sending ``execution_options={"compiled_cache": None}`` to
+    :meth:`_orm.Session.execute`, which will disable caching for the full
+    series of statements within that scope.
index d38dfa9ce19fcd2350d95f950575fecf4a580758..6b35c4a50dd647c95c5d4363e8849f3217392c41 100644 (file)
@@ -525,6 +525,7 @@ class ORMDMLState(AbstractORMCompileState):
                 dml_level_statement,
                 _adapt_on_names=False,
             )
+            fs = fs.execution_options(**orm_level_statement._execution_options)
             fs = fs.options(*orm_level_statement._with_options)
             self.select_statement = fs
             self.from_statement_ctx = (
index 63c4e86c63845832420852884e9ca3edcc5d29ed..79b43f5fe7dd05001223d75ada73690f73562d6d 100644 (file)
@@ -494,6 +494,13 @@ class ORMCompileState(AbstractORMCompileState):
         #    this will disable the ResultSetMetadata._adapt_to_context()
         #    step which we don't need, as we have result processors cached
         #    against the original SELECT statement before caching.
+
+        if "sa_top_level_orm_context" in execution_options:
+            ctx = execution_options["sa_top_level_orm_context"]
+            execution_options = ctx.query._execution_options.merge_with(
+                ctx.execution_options, execution_options
+            )
+
         if not execution_options:
             execution_options = _orm_load_exec_options
         else:
@@ -514,7 +521,8 @@ class ORMCompileState(AbstractORMCompileState):
                 "Loader depth for query is excessively deep; caching will "
                 "be disabled for additional loaders.  Consider using the "
                 "recursion_depth feature for deeply nested recursive eager "
-                "loaders."
+                "loaders.  Use the compiled_cache=None execution option to "
+                "skip this warning."
             )
             execution_options = execution_options.union(
                 {"compiled_cache": None}
index 4e677410dc05774b8bf51f975e18a429135c62a4..a0e092988386c95630f1610bbc201e3c89724884 100644 (file)
@@ -1910,11 +1910,12 @@ class SubqueryLoader(PostLoader):
 
         q = query.Query(effective_entity)
 
-        q._execution_options = q._execution_options.union(
+        q._execution_options = context.query._execution_options.merge_with(
+            context.execution_options,
             {
                 ("orig_query", SubqueryLoader): orig_query,
                 ("subquery_paths", None): (subq_path, rewritten_path),
-            }
+            },
         )
 
         q = q._set_enable_single_crit(False)
@@ -2948,6 +2949,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
         ) = self._setup_for_recursion(
             context, path, loadopt, join_depth=self.join_depth
         )
+
         if not run_loader:
             return
 
index 7f76d735d35b594aedd5341dc76611a9d6f18d4f..89d1e5c7fb2cb8281eb40536881d07ee99e92841 100644 (file)
@@ -21,8 +21,10 @@ from sqlalchemy import update
 from sqlalchemy import values
 from sqlalchemy.orm import backref
 from sqlalchemy.orm import exc as orm_exc
+from sqlalchemy.orm import immediateload
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.orm import synonym
@@ -1120,6 +1122,72 @@ class UpdateDeleteTest(fixtures.MappedTest):
             list(zip([25, 37, 29, 27])),
         )
 
+    @testing.requires.update_returning
+    @testing.combinations(
+        selectinload,
+        immediateload,
+        argnames="loader_fn",
+    )
+    @testing.variation("opt_location", ["statement", "execute"])
+    def test_update_returning_eagerload_propagate(
+        self, loader_fn, connection, opt_location
+    ):
+        User = self.classes.User
+
+        catch_opts = []
+
+        @event.listens_for(connection, "before_cursor_execute")
+        def before_cursor_execute(
+            conn, cursor, statement, parameters, context, executemany
+        ):
+            catch_opts.append(
+                {
+                    k: v
+                    for k, v in context.execution_options.items()
+                    if isinstance(k, str)
+                    and k[0] != "_"
+                    and k not in ("sa_top_level_orm_context",)
+                }
+            )
+
+        sess = Session(connection)
+
+        stmt = (
+            update(User)
+            .where(User.age > 29)
+            .values({"age": User.age - 10})
+            .returning(User)
+            .options(loader_fn(User.addresses))
+        )
+
+        if opt_location.execute:
+            opts = {
+                "compiled_cache": None,
+                "user_defined": "opt1",
+                "schema_translate_map": {"foo": "bar"},
+            }
+            result = sess.scalars(
+                stmt,
+                execution_options=opts,
+            )
+        elif opt_location.statement:
+            opts = {
+                "user_defined": "opt1",
+                "schema_translate_map": {"foo": "bar"},
+            }
+            stmt = stmt.execution_options(**opts)
+            result = sess.scalars(stmt)
+        else:
+            result = ()
+            opts = None
+            opt_location.fail()
+
+        for u1 in result:
+            u1.addresses
+
+        for elem in catch_opts:
+            eq_(elem, opts)
+
     @testing.combinations(True, False, argnames="implicit_returning")
     def test_update_fetch_returning(self, implicit_returning):
         if implicit_returning:
index 53f0166c0d12370d0c0a68f6e2120bf00a7ba4bf..10582e71131f5c01e092ab672c6c434d35ace3d4 100644 (file)
@@ -1,6 +1,7 @@
 import sqlalchemy as sa
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
+from sqlalchemy import literal_column
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
@@ -259,8 +260,9 @@ class DeepRecursiveTest(_NodeTest, fixtures.MappedTest):
 
     @testing.combinations(selectinload, immediateload, argnames="loader_fn")
     @testing.combinations(4, 9, 12, 25, 41, 55, argnames="depth")
+    @testing.variation("disable_cache", [True, False])
     def test_warning_w_no_recursive_opt(
-        self, loader_fn, depth, limited_cache_conn
+        self, loader_fn, depth, limited_cache_conn, disable_cache
     ):
         connection = limited_cache_conn(27)
 
@@ -273,22 +275,35 @@ class DeepRecursiveTest(_NodeTest, fixtures.MappedTest):
                 .options(self._stack_loaders(loader_fn, depth))
             )
 
+            if disable_cache:
+                exec_opts = dict(compiled_cache=None)
+            else:
+                exec_opts = {}
+
             # note this is a magic number, it's not important that it's exact,
             # just that when someone makes a huge recursive thing,
             # it warns
-            if depth > 8:
+            if depth > 8 and not disable_cache:
                 with expect_warnings(
                     "Loader depth for query is excessively deep; "
                     "caching will be disabled for additional loaders."
                 ):
                     with Session(connection) as s:
-                        result = s.scalars(stmt)
+                        result = s.scalars(stmt, execution_options=exec_opts)
                         self._assert_depth(result.one(), depth)
             else:
                 with Session(connection) as s:
-                    result = s.scalars(stmt)
+                    result = s.scalars(stmt, execution_options=exec_opts)
                     self._assert_depth(result.one(), depth)
 
+        if disable_cache:
+            clen = len(connection.engine._compiled_cache)
+            assert clen == 0
+            # limited_cache_conn wants to confirm the cache was used,
+            # so popualte in the case that we know we didn't use it
+            connection.execute(select(1))
+            connection.execute(select(1).where(literal_column("1") == 1))
+
 
 # TODO:
 # we should do another set of tests using Node -> Edge -> Node
index 6d599d68eaf5cbd5dc999b0594a66d71fbef34d6..b304ac574540085e2f5837fbfc35657fe617b5c9 100644 (file)
@@ -22,12 +22,15 @@ from sqlalchemy.orm import attributes
 from sqlalchemy.orm import backref
 from sqlalchemy.orm import close_all_sessions
 from sqlalchemy.orm import exc as orm_exc
+from sqlalchemy.orm import immediateload
 from sqlalchemy.orm import make_transient
 from sqlalchemy.orm import make_transient_to_detached
 from sqlalchemy.orm import object_session
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import subqueryload
 from sqlalchemy.orm import was_deleted
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
@@ -2295,6 +2298,66 @@ class NewStyleExecutionTest(_fixtures.FixtureTest):
         if construct.select:
             result.all()
 
+    @testing.combinations(
+        selectinload,
+        immediateload,
+        subqueryload,
+        argnames="loader_fn",
+    )
+    @testing.variation("opt_location", ["statement", "execute"])
+    def test_eagerloader_exec_option(
+        self, loader_fn, connection, opt_location
+    ):
+        User = self.classes.User
+
+        catch_opts = []
+
+        @event.listens_for(connection, "before_cursor_execute")
+        def before_cursor_execute(
+            conn, cursor, statement, parameters, context, executemany
+        ):
+            catch_opts.append(
+                {
+                    k: v
+                    for k, v in context.execution_options.items()
+                    if isinstance(k, str)
+                    and k[0] != "_"
+                    and k not in ("sa_top_level_orm_context",)
+                }
+            )
+
+        sess = Session(connection)
+
+        stmt = select(User).options(loader_fn(User.addresses))
+
+        if opt_location.execute:
+            opts = {
+                "compiled_cache": None,
+                "user_defined": "opt1",
+                "schema_translate_map": {"foo": "bar"},
+            }
+            result = sess.scalars(
+                stmt,
+                execution_options=opts,
+            )
+        elif opt_location.statement:
+            opts = {
+                "user_defined": "opt1",
+                "schema_translate_map": {"foo": "bar"},
+            }
+            stmt = stmt.execution_options(**opts)
+            result = sess.scalars(stmt)
+        else:
+            result = ()
+            opts = None
+            opt_location.fail()
+
+        for u1 in result:
+            u1.addresses
+
+        for elem in catch_opts:
+            eq_(elem, opts)
+
 
 class FlushWarningsTest(fixtures.MappedTest):
     run_setup_mappers = "each"