]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use an ORM compile state for all statements with any ORM entities anywhere
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 May 2023 13:35:30 +0000 (09:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 May 2023 14:33:08 +0000 (10:33 -0400)
As more projects are using new-style "2.0" ORM querying, it's becoming
apparent that the conditional nature of "autoflush", being based on whether
or not the given statement refers to ORM entities, is becoming more of a
key behavior. Up until now, the "ORM" flag for a statement has been loosely
based around whether or not the statement returns rows that correspond to
ORM entities or columns; the original purpose of the "ORM" flag was to
enable ORM-entity fetching rules which apply post-processing to Core result
sets as well as ORM loader strategies to the statement.  For statements
that don't build on rows that contain ORM entities, the "ORM" flag was
considered to be mostly unnecessary.

It still may be the case that "autoflush" would be better taking effect for
*all* usage of :meth:`_orm.Session.execute` and related methods, even for
purely Core SQL constructs. However, this still could impact legacy cases
where this is not expected and may be more of a 2.1 thing. For now however,
the rules for the "ORM-flag" have been opened up so that a statement that
includes ORM entities or attributes anywhere within, including in the WHERE
/ ORDER BY / GROUP BY clause alone, within scalar subqueries, etc. will
enable this flag.  This will cause "autoflush" to occur for such statements
and also be visible via the :attr:`_orm.ORMExecuteState.is_orm_statement`
event-level attribute.

Fixes: #9805
Change-Id: Idcabefc8fedd14edcf603b90e26e5982c849a1fc

doc/build/changelog/unreleased_20/9805.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
test/orm/test_bind.py
test/orm/test_core_compilation.py
test/orm/test_events.py

diff --git a/doc/build/changelog/unreleased_20/9805.rst b/doc/build/changelog/unreleased_20/9805.rst
new file mode 100644 (file)
index 0000000..a0b3ce9
--- /dev/null
@@ -0,0 +1,27 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9805
+
+    As more projects are using new-style "2.0" ORM querying, it's becoming
+    apparent that the conditional nature of "autoflush", being based on whether
+    or not the given statement refers to ORM entities, is becoming more of a
+    key behavior. Up until now, the "ORM" flag for a statement has been loosely
+    based around whether or not the statement returns rows that correspond to
+    ORM entities or columns; the original purpose of the "ORM" flag was to
+    enable ORM-entity fetching rules which apply post-processing to Core result
+    sets as well as ORM loader strategies to the statement.  For statements
+    that don't build on rows that contain ORM entities, the "ORM" flag was
+    considered to be mostly unnecessary.
+
+    It still may be the case that "autoflush" would be better taking effect for
+    *all* usage of :meth:`_orm.Session.execute` and related methods, even for
+    purely Core SQL constructs. However, this still could impact legacy cases
+    where this is not expected and may be more of a 2.1 thing. For now however,
+    the rules for the "ORM-flag" have been opened up so that a statement that
+    includes ORM entities or attributes anywhere within, including in the WHERE
+    / ORDER BY / GROUP BY clause alone, within scalar subqueries, etc. will
+    enable this flag.  This will cause "autoflush" to occur for such statements
+    and also be visible via the :attr:`_orm.ORMExecuteState.is_orm_statement`
+    event-level attribute.
+
+
index b75285ebdea36887b69c3f43bc9a91e8ced5342c..58dfd5e32af1072f3adedb813749610cc0e66e73 100644 (file)
@@ -665,7 +665,9 @@ class BulkUDCompileState(ORMDMLState):
 
         update_options += {"_subject_mapper": plugin_subject.mapper}
 
-        if not isinstance(params, list):
+        if "parententity" not in statement.table._annotations:
+            update_options += {"_dml_strategy": "core_only"}
+        elif not isinstance(params, list):
             if update_options._dml_strategy == "auto":
                 update_options += {"_dml_strategy": "orm"}
             elif update_options._dml_strategy == "bulk":
@@ -1401,6 +1403,12 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
 
         if toplevel and dml_strategy == "bulk":
             self._setup_for_bulk_update(statement, compiler)
+        elif (
+            dml_strategy == "core_only"
+            or dml_strategy == "unspecified"
+            and "parententity" not in statement.table._annotations
+        ):
+            UpdateDMLState.__init__(self, statement, compiler, **kw)
         elif not toplevel or dml_strategy in ("orm", "unspecified"):
             self._setup_for_orm_update(statement, compiler)
 
@@ -1555,10 +1563,15 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
             "_sa_orm_update_options", cls.default_update_options
         )
 
-        if update_options._dml_strategy not in ("orm", "auto", "bulk"):
+        if update_options._dml_strategy not in (
+            "orm",
+            "auto",
+            "bulk",
+            "core_only",
+        ):
             raise sa_exc.ArgumentError(
                 "Valid strategies for ORM UPDATE strategy "
-                "are 'orm', 'auto', 'bulk'"
+                "are 'orm', 'auto', 'bulk', 'core_only'"
             )
 
         result: _result.Result[Any]
@@ -1822,6 +1835,18 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
     def create_for_statement(cls, statement, compiler, **kw):
         self = cls.__new__(cls)
 
+        dml_strategy = statement._annotations.get(
+            "dml_strategy", "unspecified"
+        )
+
+        if (
+            dml_strategy == "core_only"
+            or dml_strategy == "unspecified"
+            and "parententity" not in statement.table._annotations
+        ):
+            DeleteDMLState.__init__(self, statement, compiler, **kw)
+            return self
+
         toplevel = not compiler.stack
 
         orm_level_statement = statement
@@ -1919,12 +1944,10 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
                 "session.connection().execute(stmt, parameters)"
             )
 
-        if update_options._dml_strategy not in (
-            "orm",
-            "auto",
-        ):
+        if update_options._dml_strategy not in ("orm", "auto", "core_only"):
             raise sa_exc.ArgumentError(
-                "Valid strategies for ORM DELETE strategy are 'orm', 'auto'"
+                "Valid strategies for ORM DELETE strategy are 'orm', 'auto', "
+                "'core_only'"
             )
 
         return super().orm_execute_statement(
index e778c4840852ee38b24b2259003e52c21371873d..397e90fadb815527687066d5b2a03127bdec22f1 100644 (file)
@@ -59,6 +59,7 @@ from ..sql.base import Options
 from ..sql.dml import UpdateBase
 from ..sql.elements import GroupedElement
 from ..sql.elements import TextClause
+from ..sql.selectable import CompoundSelectState
 from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
 from ..sql.selectable import LABEL_STYLE_NONE
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
@@ -314,6 +315,52 @@ class AbstractORMCompileState(CompileState):
         raise NotImplementedError()
 
 
+class AutoflushOnlyORMCompileState(AbstractORMCompileState):
+    """ORM compile state that is a passthrough, except for autoflush."""
+
+    @classmethod
+    def orm_pre_session_exec(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        is_pre_event,
+    ):
+
+        # consume result-level load_options.  These may have been set up
+        # in an ORMExecuteState hook
+        (
+            load_options,
+            execution_options,
+        ) = QueryContext.default_load_options.from_execution_options(
+            "_sa_orm_load_options",
+            {
+                "autoflush",
+            },
+            execution_options,
+            statement._execution_options,
+        )
+
+        if not is_pre_event and load_options._autoflush:
+            session._autoflush()
+
+        return statement, execution_options
+
+    @classmethod
+    def orm_setup_cursor_result(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        result,
+    ):
+        return result
+
+
 class ORMCompileState(AbstractORMCompileState):
     class default_compile_options(CacheableOptions):
         _cache_key_traversal = [
@@ -914,6 +961,13 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
         return self.element._inline if is_insert_update(self.element) else None
 
 
+@sql.base.CompileState.plugin_for("orm", "compound_select")
+class CompoundSelectCompileState(
+    AutoflushOnlyORMCompileState, CompoundSelectState
+):
+    pass
+
+
 @sql.base.CompileState.plugin_for("orm", "select")
 class ORMSelectCompileState(ORMCompileState, SelectState):
     _already_joined_edges = ()
index 3b187e49d4d965fd6fffc583f14888c415d5e171..78954a082ad7d602e664b73dedb961a465912036 100644 (file)
@@ -424,9 +424,8 @@ def expect(
         if typing.TYPE_CHECKING:
             assert isinstance(resolved, (SQLCoreOperations, ClauseElement))
 
-        if (
-            not apply_propagate_attrs._propagate_attrs
-            and resolved._propagate_attrs
+        if not apply_propagate_attrs._propagate_attrs and getattr(
+            resolved, "_propagate_attrs", None
         ):
             apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs
 
index 91106164069a960d9afb4981138dbe847df969d6..987910f0ca67b67f8d47e8d81f913c48b8c5726b 100644 (file)
@@ -1446,7 +1446,7 @@ class DMLWhereBase:
 
         for criterion in whereclause:
             where_criteria: ColumnElement[Any] = coercions.expect(
-                roles.WhereHavingRole, criterion
+                roles.WhereHavingRole, criterion, apply_propagate_attrs=self
             )
             self._where_criteria += (where_criteria,)
         return self
index 884e2b90fedbfe6652784b335fa19ab16c240bb5..54d876a8016001b946581fd3dd2a870a22270863 100644 (file)
@@ -4033,6 +4033,7 @@ class Grouping(GroupedElement, ColumnElement[_T]):
 
         # nulltype assignment issue
         self.type = getattr(element, "type", type_api.NULLTYPE)  # type: ignore
+        self._propagate_attrs = element._propagate_attrs
 
     def _with_binary_element_type(self, type_):
         return self.__class__(self.element._with_binary_element_type(type_))
index 19d4641808ba0e35e812322bb9238838881136cf..c7df0b8016a024b8f4c60f3133e22b7303584ce9 100644 (file)
@@ -4179,7 +4179,9 @@ class GenerativeSelect(SelectBase, Generative):
             self._order_by_clauses = ()
         elif __first is not _NoArg.NO_ARG:
             self._order_by_clauses += tuple(
-                coercions.expect(roles.OrderByRole, clause)
+                coercions.expect(
+                    roles.OrderByRole, clause, apply_propagate_attrs=self
+                )
                 for clause in (__first,) + clauses
             )
         return self
@@ -4220,7 +4222,9 @@ class GenerativeSelect(SelectBase, Generative):
             self._group_by_clauses = ()
         elif __first is not _NoArg.NO_ARG:
             self._group_by_clauses += tuple(
-                coercions.expect(roles.GroupByRole, clause)
+                coercions.expect(
+                    roles.GroupByRole, clause, apply_propagate_attrs=self
+                )
                 for clause in (__first,) + clauses
             )
         return self
@@ -4299,9 +4303,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
     ):
         self.keyword = keyword
         self.selects = [
-            coercions.expect(roles.CompoundElementRole, s).self_group(
-                against=self
-            )
+            coercions.expect(
+                roles.CompoundElementRole, s, apply_propagate_attrs=self
+            ).self_group(against=self)
             for s in selects
         ]
 
@@ -5966,7 +5970,7 @@ class Select(
 
         for criterion in whereclause:
             where_criteria: ColumnElement[Any] = coercions.expect(
-                roles.WhereHavingRole, criterion
+                roles.WhereHavingRole, criterion, apply_propagate_attrs=self
             )
             self._where_criteria += (where_criteria,)
         return self
@@ -5981,7 +5985,7 @@ class Select(
 
         for criterion in having:
             having_criteria = coercions.expect(
-                roles.WhereHavingRole, criterion
+                roles.WhereHavingRole, criterion, apply_propagate_attrs=self
             )
             self._having_criteria += (having_criteria,)
         return self
@@ -6002,7 +6006,8 @@ class Select(
         if expr:
             self._distinct = True
             self._distinct_on = self._distinct_on + tuple(
-                coercions.expect(roles.ByOfRole, e) for e in expr
+                coercions.expect(roles.ByOfRole, e, apply_propagate_attrs=self)
+                for e in expr
             )
         else:
             self._distinct = True
@@ -6474,6 +6479,7 @@ class ScalarSelect(
     def __init__(self, element: SelectBase) -> None:
         self.element = element
         self.type = element._scalar_type()
+        self._propagate_attrs = element._propagate_attrs
 
     def __getattr__(self, attr: str) -> Any:
         return getattr(self.element, attr)
index 13958ec91c9da467070082276e5539b52e778932..ff3d010707e0f6c0d06086966fc3e7ae549e2c37 100644 (file)
@@ -364,8 +364,8 @@ class BindIntegrationTest(_fixtures.FixtureTest):
         ),
         (
             lambda User: select(1).where(User.name == "ed"),
-            # no mapper for this one because the plugin is not "orm"
-            lambda User: {"clause": mock.ANY},
+            # changed by #9805
+            lambda User: {"clause": mock.ANY, "mapper": inspect(User)},
             "e1",
         ),
         (
index 8b28de591d3d9cce9082a7fb546d7bf192a45328..f24b97c7d497d41003d5b9f9e6a9b76ad3035811 100644 (file)
@@ -2,6 +2,7 @@ from sqlalchemy import bindparam
 from sqlalchemy import Column
 from sqlalchemy import delete
 from sqlalchemy import exc
+from sqlalchemy import exists
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import insert
@@ -15,6 +16,7 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import text
+from sqlalchemy import true
 from sqlalchemy import union
 from sqlalchemy import update
 from sqlalchemy import util
@@ -26,6 +28,7 @@ from sqlalchemy.orm import join as orm_join
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import query_expression
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
 from sqlalchemy.orm import undefer
 from sqlalchemy.orm import with_expression
 from sqlalchemy.orm import with_loader_criteria
@@ -40,6 +43,7 @@ from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import mock
 from sqlalchemy.testing import Variation
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.util import resolve_lambda
@@ -360,6 +364,129 @@ class SelectableTest(QueryTest, AssertsCompiledSQL):
         )
 
 
+class PropagateAttrsTest(QueryTest):
+    def propagate_cases():
+        return testing.combinations(
+            (lambda: select(1), False),
+            (lambda User: select(func.count(User.id)), True),
+            (
+                lambda User: select(1).select_from(select(User).subquery()),
+                True,
+            ),
+            (
+                lambda User: select(
+                    select(User.id).where(User.id == 5).scalar_subquery()
+                ),
+                True,
+            ),
+            (
+                lambda User: select(
+                    select(User.id).where(User.id == 5).label("x")
+                ),
+                True,
+            ),
+            (lambda User: select(1).select_from(User), True),
+            (lambda User: select(1).where(exists(User.id)), True),
+            (lambda User: select(1).where(~exists(User.id)), True),
+            (
+                # changed as part of #9805
+                lambda User: select(1).where(User.id == 1),
+                True,
+            ),
+            (
+                # changed as part of #9805
+                lambda User, user_table: select(func.count(1))
+                .select_from(user_table)
+                .group_by(user_table.c.id)
+                .having(User.id == 1),
+                True,
+            ),
+            (
+                # changed as part of #9805
+                lambda User, user_table: select(1)
+                .select_from(user_table)
+                .order_by(User.id),
+                True,
+            ),
+            (
+                # changed as part of #9805
+                lambda User, user_table: select(1)
+                .select_from(user_table)
+                .group_by(User.id),
+                True,
+            ),
+            (
+                lambda User, user_table: select(user_table).join(
+                    aliased(User), true()
+                ),
+                True,
+            ),
+            (
+                # changed as part of #9805
+                lambda User, user_table: select(1)
+                .distinct(User.id)
+                .select_from(user_table),
+                True,
+                testing.requires.supports_distinct_on,
+            ),
+            (lambda user_table: select(user_table), False),
+            (lambda User: select(User), True),
+            (lambda User: union(select(User), select(User)), True),
+            (
+                lambda User: select(1).select_from(
+                    union(select(User), select(User)).subquery()
+                ),
+                True,
+            ),
+            (lambda User: select(User.id), True),
+            # these are meaningless, correlate by itself has no effect
+            (lambda User: select(1).correlate(User), False),
+            (lambda User: select(1).correlate_except(User), False),
+            (lambda User: delete(User).where(User.id > 20), True),
+            (
+                lambda User, user_table: delete(user_table).where(
+                    User.id > 20
+                ),
+                True,
+            ),
+            (lambda User: update(User).values(name="x"), True),
+            (
+                lambda User, user_table: update(user_table)
+                .values(name="x")
+                .where(User.id > 20),
+                True,
+            ),
+            (lambda User: insert(User).values(name="x"), True),
+        )
+
+    @propagate_cases()
+    def test_propagate_attr_yesno(self, test_case, expected):
+        User = self.classes.User
+        user_table = self.tables.users
+
+        stmt = resolve_lambda(test_case, User=User, user_table=user_table)
+
+        eq_(bool(stmt._propagate_attrs), expected)
+
+    @propagate_cases()
+    def test_autoflushes(self, test_case, expected):
+        User = self.classes.User
+        user_table = self.tables.users
+
+        stmt = resolve_lambda(test_case, User=User, user_table=user_table)
+
+        with Session(testing.db) as s:
+
+            with mock.patch.object(s, "_autoflush", wrap=True) as before_flush:
+                r = s.execute(stmt)
+                r.close()
+
+        if expected:
+            eq_(before_flush.mock_calls, [mock.call()])
+        else:
+            eq_(before_flush.mock_calls, [])
+
+
 class DMLTest(QueryTest, AssertsCompiledSQL):
     __dialect__ = "default"
 
index 07d27451d9280f5ec682f51ef499f819df448626..5b575fd7cffe5bf1075c9fa13d5be4b321db0389 100644 (file)
@@ -403,7 +403,10 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
 
     @testing.combinations(
         (lambda: select(1), True),
-        (lambda User: select(User).union(select(User)), True),
+        (
+            lambda user_table: select(user_table).union(select(user_table)),
+            True,
+        ),
         (lambda: text("select * from users"), False),
     )
     def test_non_orm_statements(self, stmt, is_select):
@@ -411,8 +414,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
 
         canary = self._flag_fixture(sess)
 
-        User, Address = self.classes("User", "Address")
-        stmt = testing.resolve_lambda(stmt, User=User)
+        user_table = self.tables.users
+        stmt = testing.resolve_lambda(stmt, user_table=user_table)
         sess.execute(stmt).all()
 
         eq_(