From: Mike Bayer Date: Fri, 19 May 2023 13:35:30 +0000 (-0400) Subject: use an ORM compile state for all statements with any ORM entities anywhere X-Git-Tag: rel_2_0_15~3^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=a2e27eb4d718f732494ff008aad1d0cd56c4ad88;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git use an ORM compile state for all statements with any ORM entities anywhere As more projects are using new-style "2.0" ORM querying, it's becoming apparent that the conditional nature of "autoflush", being based on whether or not the given statement refers to ORM entities, is becoming more of a key behavior. Up until now, the "ORM" flag for a statement has been loosely based around whether or not the statement returns rows that correspond to ORM entities or columns; the original purpose of the "ORM" flag was to enable ORM-entity fetching rules which apply post-processing to Core result sets as well as ORM loader strategies to the statement. For statements that don't build on rows that contain ORM entities, the "ORM" flag was considered to be mostly unnecessary. It still may be the case that "autoflush" would be better taking effect for *all* usage of :meth:`_orm.Session.execute` and related methods, even for purely Core SQL constructs. However, this still could impact legacy cases where this is not expected and may be more of a 2.1 thing. For now however, the rules for the "ORM-flag" have been opened up so that a statement that includes ORM entities or attributes anywhere within, including in the WHERE / ORDER BY / GROUP BY clause alone, within scalar subqueries, etc. will enable this flag. This will cause "autoflush" to occur for such statements and also be visible via the :attr:`_orm.ORMExecuteState.is_orm_statement` event-level attribute. Fixes: #9805 Change-Id: Idcabefc8fedd14edcf603b90e26e5982c849a1fc --- diff --git a/doc/build/changelog/unreleased_20/9805.rst b/doc/build/changelog/unreleased_20/9805.rst new file mode 100644 index 0000000000..a0b3ce99b1 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9805.rst @@ -0,0 +1,27 @@ +.. change:: + :tags: bug, orm + :tickets: 9805 + + As more projects are using new-style "2.0" ORM querying, it's becoming + apparent that the conditional nature of "autoflush", being based on whether + or not the given statement refers to ORM entities, is becoming more of a + key behavior. Up until now, the "ORM" flag for a statement has been loosely + based around whether or not the statement returns rows that correspond to + ORM entities or columns; the original purpose of the "ORM" flag was to + enable ORM-entity fetching rules which apply post-processing to Core result + sets as well as ORM loader strategies to the statement. For statements + that don't build on rows that contain ORM entities, the "ORM" flag was + considered to be mostly unnecessary. + + It still may be the case that "autoflush" would be better taking effect for + *all* usage of :meth:`_orm.Session.execute` and related methods, even for + purely Core SQL constructs. However, this still could impact legacy cases + where this is not expected and may be more of a 2.1 thing. For now however, + the rules for the "ORM-flag" have been opened up so that a statement that + includes ORM entities or attributes anywhere within, including in the WHERE + / ORDER BY / GROUP BY clause alone, within scalar subqueries, etc. will + enable this flag. This will cause "autoflush" to occur for such statements + and also be visible via the :attr:`_orm.ORMExecuteState.is_orm_statement` + event-level attribute. + + diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index b75285ebde..58dfd5e32a 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -665,7 +665,9 @@ class BulkUDCompileState(ORMDMLState): update_options += {"_subject_mapper": plugin_subject.mapper} - if not isinstance(params, list): + if "parententity" not in statement.table._annotations: + update_options += {"_dml_strategy": "core_only"} + elif not isinstance(params, list): if update_options._dml_strategy == "auto": update_options += {"_dml_strategy": "orm"} elif update_options._dml_strategy == "bulk": @@ -1401,6 +1403,12 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): if toplevel and dml_strategy == "bulk": self._setup_for_bulk_update(statement, compiler) + elif ( + dml_strategy == "core_only" + or dml_strategy == "unspecified" + and "parententity" not in statement.table._annotations + ): + UpdateDMLState.__init__(self, statement, compiler, **kw) elif not toplevel or dml_strategy in ("orm", "unspecified"): self._setup_for_orm_update(statement, compiler) @@ -1555,10 +1563,15 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): "_sa_orm_update_options", cls.default_update_options ) - if update_options._dml_strategy not in ("orm", "auto", "bulk"): + if update_options._dml_strategy not in ( + "orm", + "auto", + "bulk", + "core_only", + ): raise sa_exc.ArgumentError( "Valid strategies for ORM UPDATE strategy " - "are 'orm', 'auto', 'bulk'" + "are 'orm', 'auto', 'bulk', 'core_only'" ) result: _result.Result[Any] @@ -1822,6 +1835,18 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): def create_for_statement(cls, statement, compiler, **kw): self = cls.__new__(cls) + dml_strategy = statement._annotations.get( + "dml_strategy", "unspecified" + ) + + if ( + dml_strategy == "core_only" + or dml_strategy == "unspecified" + and "parententity" not in statement.table._annotations + ): + DeleteDMLState.__init__(self, statement, compiler, **kw) + return self + toplevel = not compiler.stack orm_level_statement = statement @@ -1919,12 +1944,10 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): "session.connection().execute(stmt, parameters)" ) - if update_options._dml_strategy not in ( - "orm", - "auto", - ): + if update_options._dml_strategy not in ("orm", "auto", "core_only"): raise sa_exc.ArgumentError( - "Valid strategies for ORM DELETE strategy are 'orm', 'auto'" + "Valid strategies for ORM DELETE strategy are 'orm', 'auto', " + "'core_only'" ) return super().orm_execute_statement( diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index e778c48408..397e90fadb 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -59,6 +59,7 @@ from ..sql.base import Options from ..sql.dml import UpdateBase from ..sql.elements import GroupedElement from ..sql.elements import TextClause +from ..sql.selectable import CompoundSelectState from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL @@ -314,6 +315,52 @@ class AbstractORMCompileState(CompileState): raise NotImplementedError() +class AutoflushOnlyORMCompileState(AbstractORMCompileState): + """ORM compile state that is a passthrough, except for autoflush.""" + + @classmethod + def orm_pre_session_exec( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_pre_event, + ): + + # consume result-level load_options. These may have been set up + # in an ORMExecuteState hook + ( + load_options, + execution_options, + ) = QueryContext.default_load_options.from_execution_options( + "_sa_orm_load_options", + { + "autoflush", + }, + execution_options, + statement._execution_options, + ) + + if not is_pre_event and load_options._autoflush: + session._autoflush() + + return statement, execution_options + + @classmethod + def orm_setup_cursor_result( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + return result + + class ORMCompileState(AbstractORMCompileState): class default_compile_options(CacheableOptions): _cache_key_traversal = [ @@ -914,6 +961,13 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): return self.element._inline if is_insert_update(self.element) else None +@sql.base.CompileState.plugin_for("orm", "compound_select") +class CompoundSelectCompileState( + AutoflushOnlyORMCompileState, CompoundSelectState +): + pass + + @sql.base.CompileState.plugin_for("orm", "select") class ORMSelectCompileState(ORMCompileState, SelectState): _already_joined_edges = () diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 3b187e49d4..78954a082a 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -424,9 +424,8 @@ def expect( if typing.TYPE_CHECKING: assert isinstance(resolved, (SQLCoreOperations, ClauseElement)) - if ( - not apply_propagate_attrs._propagate_attrs - and resolved._propagate_attrs + if not apply_propagate_attrs._propagate_attrs and getattr( + resolved, "_propagate_attrs", None ): apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 9110616406..987910f0ca 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -1446,7 +1446,7 @@ class DMLWhereBase: for criterion in whereclause: where_criteria: ColumnElement[Any] = coercions.expect( - roles.WhereHavingRole, criterion + roles.WhereHavingRole, criterion, apply_propagate_attrs=self ) self._where_criteria += (where_criteria,) return self diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 884e2b90fe..54d876a801 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -4033,6 +4033,7 @@ class Grouping(GroupedElement, ColumnElement[_T]): # nulltype assignment issue self.type = getattr(element, "type", type_api.NULLTYPE) # type: ignore + self._propagate_attrs = element._propagate_attrs def _with_binary_element_type(self, type_): return self.__class__(self.element._with_binary_element_type(type_)) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 19d4641808..c7df0b8016 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -4179,7 +4179,9 @@ class GenerativeSelect(SelectBase, Generative): self._order_by_clauses = () elif __first is not _NoArg.NO_ARG: self._order_by_clauses += tuple( - coercions.expect(roles.OrderByRole, clause) + coercions.expect( + roles.OrderByRole, clause, apply_propagate_attrs=self + ) for clause in (__first,) + clauses ) return self @@ -4220,7 +4222,9 @@ class GenerativeSelect(SelectBase, Generative): self._group_by_clauses = () elif __first is not _NoArg.NO_ARG: self._group_by_clauses += tuple( - coercions.expect(roles.GroupByRole, clause) + coercions.expect( + roles.GroupByRole, clause, apply_propagate_attrs=self + ) for clause in (__first,) + clauses ) return self @@ -4299,9 +4303,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): ): self.keyword = keyword self.selects = [ - coercions.expect(roles.CompoundElementRole, s).self_group( - against=self - ) + coercions.expect( + roles.CompoundElementRole, s, apply_propagate_attrs=self + ).self_group(against=self) for s in selects ] @@ -5966,7 +5970,7 @@ class Select( for criterion in whereclause: where_criteria: ColumnElement[Any] = coercions.expect( - roles.WhereHavingRole, criterion + roles.WhereHavingRole, criterion, apply_propagate_attrs=self ) self._where_criteria += (where_criteria,) return self @@ -5981,7 +5985,7 @@ class Select( for criterion in having: having_criteria = coercions.expect( - roles.WhereHavingRole, criterion + roles.WhereHavingRole, criterion, apply_propagate_attrs=self ) self._having_criteria += (having_criteria,) return self @@ -6002,7 +6006,8 @@ class Select( if expr: self._distinct = True self._distinct_on = self._distinct_on + tuple( - coercions.expect(roles.ByOfRole, e) for e in expr + coercions.expect(roles.ByOfRole, e, apply_propagate_attrs=self) + for e in expr ) else: self._distinct = True @@ -6474,6 +6479,7 @@ class ScalarSelect( def __init__(self, element: SelectBase) -> None: self.element = element self.type = element._scalar_type() + self._propagate_attrs = element._propagate_attrs def __getattr__(self, attr: str) -> Any: return getattr(self.element, attr) diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py index 13958ec91c..ff3d010707 100644 --- a/test/orm/test_bind.py +++ b/test/orm/test_bind.py @@ -364,8 +364,8 @@ class BindIntegrationTest(_fixtures.FixtureTest): ), ( lambda User: select(1).where(User.name == "ed"), - # no mapper for this one because the plugin is not "orm" - lambda User: {"clause": mock.ANY}, + # changed by #9805 + lambda User: {"clause": mock.ANY, "mapper": inspect(User)}, "e1", ), ( diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 8b28de591d..f24b97c7d4 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -2,6 +2,7 @@ from sqlalchemy import bindparam from sqlalchemy import Column from sqlalchemy import delete from sqlalchemy import exc +from sqlalchemy import exists from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import insert @@ -15,6 +16,7 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import union from sqlalchemy import update from sqlalchemy import util @@ -26,6 +28,7 @@ from sqlalchemy.orm import join as orm_join from sqlalchemy.orm import joinedload from sqlalchemy.orm import query_expression from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session from sqlalchemy.orm import undefer from sqlalchemy.orm import with_expression from sqlalchemy.orm import with_loader_criteria @@ -40,6 +43,7 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import mock from sqlalchemy.testing import Variation from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.util import resolve_lambda @@ -360,6 +364,129 @@ class SelectableTest(QueryTest, AssertsCompiledSQL): ) +class PropagateAttrsTest(QueryTest): + def propagate_cases(): + return testing.combinations( + (lambda: select(1), False), + (lambda User: select(func.count(User.id)), True), + ( + lambda User: select(1).select_from(select(User).subquery()), + True, + ), + ( + lambda User: select( + select(User.id).where(User.id == 5).scalar_subquery() + ), + True, + ), + ( + lambda User: select( + select(User.id).where(User.id == 5).label("x") + ), + True, + ), + (lambda User: select(1).select_from(User), True), + (lambda User: select(1).where(exists(User.id)), True), + (lambda User: select(1).where(~exists(User.id)), True), + ( + # changed as part of #9805 + lambda User: select(1).where(User.id == 1), + True, + ), + ( + # changed as part of #9805 + lambda User, user_table: select(func.count(1)) + .select_from(user_table) + .group_by(user_table.c.id) + .having(User.id == 1), + True, + ), + ( + # changed as part of #9805 + lambda User, user_table: select(1) + .select_from(user_table) + .order_by(User.id), + True, + ), + ( + # changed as part of #9805 + lambda User, user_table: select(1) + .select_from(user_table) + .group_by(User.id), + True, + ), + ( + lambda User, user_table: select(user_table).join( + aliased(User), true() + ), + True, + ), + ( + # changed as part of #9805 + lambda User, user_table: select(1) + .distinct(User.id) + .select_from(user_table), + True, + testing.requires.supports_distinct_on, + ), + (lambda user_table: select(user_table), False), + (lambda User: select(User), True), + (lambda User: union(select(User), select(User)), True), + ( + lambda User: select(1).select_from( + union(select(User), select(User)).subquery() + ), + True, + ), + (lambda User: select(User.id), True), + # these are meaningless, correlate by itself has no effect + (lambda User: select(1).correlate(User), False), + (lambda User: select(1).correlate_except(User), False), + (lambda User: delete(User).where(User.id > 20), True), + ( + lambda User, user_table: delete(user_table).where( + User.id > 20 + ), + True, + ), + (lambda User: update(User).values(name="x"), True), + ( + lambda User, user_table: update(user_table) + .values(name="x") + .where(User.id > 20), + True, + ), + (lambda User: insert(User).values(name="x"), True), + ) + + @propagate_cases() + def test_propagate_attr_yesno(self, test_case, expected): + User = self.classes.User + user_table = self.tables.users + + stmt = resolve_lambda(test_case, User=User, user_table=user_table) + + eq_(bool(stmt._propagate_attrs), expected) + + @propagate_cases() + def test_autoflushes(self, test_case, expected): + User = self.classes.User + user_table = self.tables.users + + stmt = resolve_lambda(test_case, User=User, user_table=user_table) + + with Session(testing.db) as s: + + with mock.patch.object(s, "_autoflush", wrap=True) as before_flush: + r = s.execute(stmt) + r.close() + + if expected: + eq_(before_flush.mock_calls, [mock.call()]) + else: + eq_(before_flush.mock_calls, []) + + class DMLTest(QueryTest, AssertsCompiledSQL): __dialect__ = "default" diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 07d27451d9..5b575fd7cf 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -403,7 +403,10 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): @testing.combinations( (lambda: select(1), True), - (lambda User: select(User).union(select(User)), True), + ( + lambda user_table: select(user_table).union(select(user_table)), + True, + ), (lambda: text("select * from users"), False), ) def test_non_orm_statements(self, stmt, is_select): @@ -411,8 +414,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): canary = self._flag_fixture(sess) - User, Address = self.classes("User", "Address") - stmt = testing.resolve_lambda(stmt, User=User) + user_table = self.tables.users + stmt = testing.resolve_lambda(stmt, user_table=user_table) sess.execute(stmt).all() eq_(