From: Mike Bayer Date: Sun, 30 Aug 2020 22:13:36 +0000 (-0400) Subject: Support extra / single inh criteria with ORM update/delete X-Git-Tag: rel_1_4_0b1~139 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=575b6dded9a25fca693f0aa7f6d7c6e735490460;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support extra / single inh criteria with ORM update/delete The ORM bulk update and delete operations, historically available via the :meth:`_orm.Query.update` and :meth:`_orm.Query.delete` methods as well as via the :class:`_dml.Update` and :class:`_dml.Delete` constructs for :term:`2.0 style` execution, will now automatically accommodate for the additional WHERE criteria needed for a single-table inheritance discrminiator. Joined-table inheritance is still not directly supported. The new :func:`_orm.with_loader_criteria` construct is also supported for all mappings with bulk update/delete. Fixes: #5018 Fixes: #3903 Change-Id: Id90827cc7e2bc713d1255127f908c8e133de9295 --- diff --git a/doc/build/changelog/unreleased_14/5018.rst b/doc/build/changelog/unreleased_14/5018.rst new file mode 100644 index 0000000000..355bb1cc56 --- /dev/null +++ b/doc/build/changelog/unreleased_14/5018.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: usecase, orm + :tickets: 5018, 3903 + + The ORM bulk update and delete operations, historically available via the + :meth:`_orm.Query.update` and :meth:`_orm.Query.delete` methods as well as + via the :class:`_dml.Update` and :class:`_dml.Delete` constructs for + :term:`2.0 style` execution, will now automatically accommodate for the + additional WHERE criteria needed for a single-table inheritance + discriminator in order to limit the statement to rows referring to the + specific subtype requested. The new :func:`_orm.with_loader_criteria` + construct is also supported for with bulk update/delete operations. diff --git a/doc/build/orm/session_basics.rst b/doc/build/orm/session_basics.rst index 78f8f3234f..110e3df131 100644 --- a/doc/build/orm/session_basics.rst +++ b/doc/build/orm/session_basics.rst @@ -507,10 +507,19 @@ values for ``synchronize_session`` are supported: Similar guidelines as those detailed at :ref:`multi_table_updates` may be applied. - * The polymorphic identity WHERE criteria is **not** included - for single- or - joined- table updates - this must be added **manually**, even - for single table inheritance. + * The WHERE criteria needed in order to limit the polymorphic identity to + specific subclasses for single-table-inheritance mappings **is included + automatically** . This only applies to a subclass mapper that has no + table of its own. + + .. versionchanged:: 1.4 ORM updates/deletes now automatically + accommodate for the WHERE criteria added for single-inheritance + mappings. + + * The :func:`_orm.with_loader_criteria` option **is supported** by ORM + update and delete operations; criteria here will be added to that of the + UPDATE or DELETE statement being emitted, as well as taken into account + during the "synchronize" process. * In order to intercept bulk UPDATE and DELETE operations with event handlers, use the :meth:`_orm.SessionEvents.do_orm_execute` event. diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 068c850732..b1ff1a0497 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -729,6 +729,8 @@ class ORMOption(ExecutableOption): _is_compile_state = False + _is_criteria_option = False + class LoaderOption(ORMOption): """Describe a loader modification to an ORM statement at compilation time. @@ -743,6 +745,27 @@ class LoaderOption(ORMOption): """Apply a modification to a given :class:`.CompileState`.""" +class CriteriaOption(ORMOption): + """Describe a WHERE criteria modification to an ORM statement at + compilation time. + + .. versionadded:: 1.4 + + """ + + _is_compile_state = True + _is_criteria_option = True + + def process_compile_state(self, compile_state): + """Apply a modification to a given :class:`.CompileState`.""" + + def get_global_criteria(self, attributes): + """update additional entity criteria options in the given + attributes dictionary. + + """ + + class UserDefinedOption(ORMOption): """Base class for a user-defined option that can be consumed from the :meth:`.SessionEvents.do_orm_execute` event hook. diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 49b29a6bc0..d05381c1d2 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1856,6 +1856,43 @@ class BulkUDCompileState(CompileState): return result + @classmethod + def _adjust_for_extra_criteria(cls, global_attributes, ext_info): + """Apply extra criteria filtering. + + For all distinct single-table-inheritance mappers represented in the + table being updated or deleted, produce additional WHERE criteria such + that only the appropriate subtypes are selected from the total results. + + Additionally, add WHERE criteria originating from LoaderCriteriaOptions + collected from the statement. + + """ + + return_crit = () + + adapter = ext_info._adapter if ext_info.is_aliased_class else None + + if ( + "additional_entity_criteria", + ext_info.mapper, + ) in global_attributes: + return_crit += tuple( + ae._resolve_where_criteria(ext_info) + for ae in global_attributes[ + ("additional_entity_criteria", ext_info.mapper) + ] + if ae.include_aliases or ae.entity is ext_info + ) + + if ext_info.mapper._single_table_criterion is not None: + return_crit += (ext_info.mapper._single_table_criterion,) + + if adapter: + return_crit = tuple(adapter.traverse(crit) for crit in return_crit) + + return return_crit + @classmethod def _do_pre_synchronize_evaluate( cls, @@ -1873,10 +1910,22 @@ class BulkUDCompileState(CompileState): try: evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) + crit = () if statement._where_criteria: - eval_condition = evaluator_compiler.process( - *statement._where_criteria + crit += statement._where_criteria + + global_attributes = {} + for opt in statement._with_options: + if opt._is_criteria_option: + opt.get_global_criteria(global_attributes) + + if global_attributes: + crit += cls._adjust_for_extra_criteria( + global_attributes, mapper ) + + if crit: + eval_condition = evaluator_compiler.process(*crit) else: def eval_condition(obj): @@ -1920,16 +1969,17 @@ class BulkUDCompileState(CompileState): # TODO: detect when the where clause is a trivial primary key match. matched_objects = [ - obj - for (cls, pk, identity_token,), obj in session.identity_map.items() - if issubclass(cls, target_cls) - and eval_condition(obj) + state.obj() + for state in session.identity_map.all_states() + if state.mapper.isa(mapper) + and eval_condition(state.obj()) and ( update_options._refresh_identity_token is None # TODO: coverage for the case where horiziontal sharding # invokes an update() or delete() given an explicit identity # token up front - or identity_token == update_options._refresh_identity_token + or state.identity_token + == update_options._refresh_identity_token ) ] return update_options + { @@ -2003,8 +2053,10 @@ class BulkUDCompileState(CompileState): ): mapper = update_options._subject_mapper - select_stmt = select( - *(mapper.primary_key + (mapper.select_identity_token,)) + select_stmt = ( + select(*(mapper.primary_key + (mapper.select_identity_token,))) + .select_from(mapper) + .options(*statement._with_options) ) select_stmt._where_criteria = statement._where_criteria @@ -2075,12 +2127,20 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): self = cls.__new__(cls) - self.mapper = mapper = statement.table._annotations.get( - "parentmapper", None - ) + ext_info = statement.table._annotations["parententity"] + + self.mapper = mapper = ext_info.mapper + + self.extra_criteria_entities = {} self._resolved_values = cls._get_resolved_values(mapper, statement) + extra_criteria_attributes = {} + + for opt in statement._with_options: + if opt._is_criteria_option: + opt.get_global_criteria(extra_criteria_attributes) + if not statement._preserve_parameter_order and statement._values: self._resolved_values = dict(self._resolved_values) @@ -2097,6 +2157,12 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): elif statement._values: new_stmt._values = self._resolved_values + new_crit = cls._adjust_for_extra_criteria( + extra_criteria_attributes, mapper + ) + if new_crit: + new_stmt = new_stmt.where(*new_crit) + # if we are against a lambda statement we might not be the # topmost object that received per-execute annotations top_level_stmt = compiler.statement @@ -2211,11 +2277,25 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState): def create_for_statement(cls, statement, compiler, **kw): self = cls.__new__(cls) - self.mapper = mapper = statement.table._annotations.get( - "parentmapper", None - ) + ext_info = statement.table._annotations["parententity"] + self.mapper = mapper = ext_info.mapper top_level_stmt = compiler.statement + + self.extra_criteria_entities = {} + + extra_criteria_attributes = {} + + for opt in statement._with_options: + if opt._is_criteria_option: + opt.get_global_criteria(extra_criteria_attributes) + + new_crit = cls._adjust_for_extra_criteria( + extra_criteria_attributes, mapper + ) + if new_crit: + statement = statement.where(*new_crit) + if ( mapper and top_level_stmt._annotations.get("synchronize_session", None) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 82fad0815d..271a441f00 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -23,7 +23,7 @@ from .base import object_state # noqa from .base import state_attribute_str # noqa from .base import state_class_str # noqa from .base import state_str # noqa -from .interfaces import LoaderOption +from .interfaces import CriteriaOption from .interfaces import MapperProperty # noqa from .interfaces import ORMColumnsClauseRole from .interfaces import ORMEntityColumnsClauseRole @@ -856,7 +856,7 @@ class AliasedInsp( return "aliased(%s)" % (self._target.__name__,) -class LoaderCriteriaOption(LoaderOption): +class LoaderCriteriaOption(CriteriaOption): """Add additional WHERE criteria to the load for all occurrences of a particular entity. @@ -1026,8 +1026,11 @@ class LoaderCriteriaOption(LoaderOption): # if options to limit the criteria to immediate query only, # use compile_state.attributes instead + self.get_global_criteria(compile_state.global_attributes) + + def get_global_criteria(self, attributes): for mp in self._all_mappers(): - load_criteria = compile_state.global_attributes.setdefault( + load_criteria = attributes.setdefault( ("additional_entity_criteria", mp), [] ) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index a9bccaeff8..b7151ac7b0 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -12,6 +12,7 @@ Provide :class:`_expression.Insert`, :class:`_expression.Update` and from sqlalchemy.types import NullType from . import coercions from . import roles +from .base import _entity_namespace_key from .base import _from_objects from .base import _generative from .base import ColumnCollection @@ -983,10 +984,30 @@ class DMLWhereBase(object): ) def filter(self, *criteria): - """A synonym for the :meth:`_dml.DMLWhereBase.where` method.""" + """A synonym for the :meth:`_dml.DMLWhereBase.where` method. + + .. versionadded:: 1.4 + + """ return self.where(*criteria) + def _filter_by_zero(self): + return self.table + + def filter_by(self, **kwargs): + r"""apply the given filtering criterion as a WHERE clause + to this select. + + """ + from_entity = self._filter_by_zero() + + clauses = [ + _entity_namespace_key(from_entity, key) == value + for key, value in kwargs.items() + ] + return self.filter(*clauses) + @property def whereclause(self): """Return the completed WHERE clause for this :class:`.DMLWhereBase` diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index 5360aecb70..aec5d05534 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -19,7 +19,9 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import synonym +from sqlalchemy.orm import with_loader_criteria from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ @@ -540,6 +542,68 @@ class UpdateDeleteTest(fixtures.MappedTest): list(zip([15, 27, 19, 27])), ) + @testing.combinations( + ("fetch", False), + ("fetch", True), + ("evaluate", False), + ("evaluate", True), + ) + def test_update_with_loader_criteria(self, fetchstyle, future): + User = self.classes.User + + sess = Session(testing.db, future=True) + + john, jack, jill, jane = ( + sess.execute(select(User).order_by(User.id)).scalars().all() + ) + + sess.execute( + update(User) + .options( + with_loader_criteria(User, User.name.in_(["jill", "jane"])) + ) + .where(User.age > 29) + .values(age=User.age - 10) + .execution_options(synchronize_session=fetchstyle) + ) + + eq_([john.age, jack.age, jill.age, jane.age], [25, 47, 29, 27]) + eq_( + sess.execute(select(User.age).order_by(User.id)).all(), + list(zip([25, 47, 29, 27])), + ) + + @testing.combinations( + ("fetch", False), + ("fetch", True), + ("evaluate", False), + ("evaluate", True), + ) + def test_delete_with_loader_criteria(self, fetchstyle, future): + User = self.classes.User + + sess = Session(testing.db, future=True) + + john, jack, jill, jane = ( + sess.execute(select(User).order_by(User.id)).scalars().all() + ) + + sess.execute( + delete(User) + .options( + with_loader_criteria(User, User.name.in_(["jill", "jane"])) + ) + .where(User.age > 29) + .execution_options(synchronize_session=fetchstyle) + ) + + assert jane not in sess + assert jack in sess + eq_( + sess.execute(select(User.age).order_by(User.id)).all(), + list(zip([25, 47, 29])), + ) + def test_update_against_table_col(self): User, users = self.classes.User, self.tables.users @@ -1646,3 +1710,136 @@ class InheritTest(fixtures.DeclarativeMappedTest): set(s.query(Person.name, Engineer.engineer_name)), set([("e1", "e1"), ("e22", "e55")]), ) + + +class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest): + __backend__ = True + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class Staff(Base): + __tablename__ = "staff" + position = Column(String(10), nullable=False) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column(String(5)) + stats = Column(String(5)) + __mapper_args__ = {"polymorphic_on": position} + + class Sales(Staff): + sales_stats = Column(String(5)) + __mapper_args__ = {"polymorphic_identity": "sales"} + + class Support(Staff): + support_stats = Column(String(5)) + __mapper_args__ = {"polymorphic_identity": "support"} + + @classmethod + def insert_data(cls, connection): + with sessionmaker(connection).begin() as session: + Sales, Support = ( + cls.classes.Sales, + cls.classes.Support, + ) + session.add_all( + [ + Sales(name="n1", sales_stats="1", stats="a"), + Sales(name="n2", sales_stats="2", stats="b"), + Support(name="n1", support_stats="3", stats="c"), + Support(name="n2", support_stats="4", stats="d"), + ] + ) + + @testing.combinations( + ("fetch", False), + ("fetch", True), + ("evaluate", False), + ("evaluate", True), + ) + def test_update(self, fetchstyle, future): + Staff, Sales, Support = self.classes("Staff", "Sales", "Support") + + sess = Session() + + en1, en2 = ( + sess.execute(select(Sales).order_by(Sales.sales_stats)) + .scalars() + .all() + ) + mn1, mn2 = ( + sess.execute(select(Support).order_by(Support.support_stats)) + .scalars() + .all() + ) + + if future: + sess.execute( + update(Sales) + .filter_by(name="n1") + .values(stats="p") + .execution_options(synchronize_session=fetchstyle) + ) + else: + sess.query(Sales).filter_by(name="n1").update( + {"stats": "p"}, synchronize_session=fetchstyle + ) + + eq_(en1.stats, "p") + eq_(mn1.stats, "c") + eq_( + sess.execute( + select(Staff.position, Staff.name, Staff.stats).order_by( + Staff.id + ) + ).all(), + [ + ("sales", "n1", "p"), + ("sales", "n2", "b"), + ("support", "n1", "c"), + ("support", "n2", "d"), + ], + ) + + @testing.combinations( + ("fetch", False), + ("fetch", True), + ("evaluate", False), + ("evaluate", True), + ) + def test_delete(self, fetchstyle, future): + Staff, Sales, Support = self.classes("Staff", "Sales", "Support") + + sess = Session() + en1, en2 = sess.query(Sales).order_by(Sales.sales_stats).all() + mn1, mn2 = sess.query(Support).order_by(Support.support_stats).all() + + if future: + sess.execute( + delete(Sales) + .filter_by(name="n1") + .execution_options(synchronize_session=fetchstyle) + ) + else: + sess.query(Sales).filter_by(name="n1").delete( + synchronize_session=fetchstyle + ) + assert en1 not in sess + assert en2 in sess + assert mn1 in sess + assert mn2 in sess + + eq_( + sess.execute( + select(Staff.position, Staff.name, Staff.stats).order_by( + Staff.id + ) + ).all(), + [ + ("sales", "n2", "b"), + ("support", "n1", "c"), + ("support", "n2", "d"), + ], + )