]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support extra / single inh criteria with ORM update/delete
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Aug 2020 22:13:36 +0000 (18:13 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Aug 2020 23:45:04 +0000 (19:45 -0400)
The ORM bulk update and delete operations, historically available via the
:meth:`_orm.Query.update` and :meth:`_orm.Query.delete` methods as well as
via the :class:`_dml.Update` and :class:`_dml.Delete` constructs for
:term:`2.0 style` execution, will now automatically accommodate for the
additional WHERE criteria needed for a single-table inheritance
discrminiator.   Joined-table inheritance is still not directly
supported. The new :func:`_orm.with_loader_criteria` construct is also
supported for all mappings with bulk update/delete.

Fixes: #5018
Fixes: #3903
Change-Id: Id90827cc7e2bc713d1255127f908c8e133de9295

doc/build/changelog/unreleased_14/5018.rst [new file with mode: 0644]
doc/build/orm/session_basics.rst
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/dml.py
test/orm/test_update_delete.py

diff --git a/doc/build/changelog/unreleased_14/5018.rst b/doc/build/changelog/unreleased_14/5018.rst
new file mode 100644 (file)
index 0000000..355bb1c
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 5018, 3903
+
+    The ORM bulk update and delete operations, historically available via the
+    :meth:`_orm.Query.update` and :meth:`_orm.Query.delete` methods as well as
+    via the :class:`_dml.Update` and :class:`_dml.Delete` constructs for
+    :term:`2.0 style` execution, will now automatically accommodate for the
+    additional WHERE criteria needed for a single-table inheritance
+    discriminator in order to limit the statement to rows referring to the
+    specific subtype requested.   The new :func:`_orm.with_loader_criteria`
+    construct is also supported for with bulk update/delete operations.
index 78f8f3234fb58ed33251f6755c7dec83c29faa62..110e3df1312dd6987f73146d883ec79fe98d79db 100644 (file)
@@ -507,10 +507,19 @@ values for ``synchronize_session`` are supported:
       Similar guidelines as those detailed at :ref:`multi_table_updates`
       may be applied.
 
-    * The polymorphic identity WHERE criteria is **not** included
-      for single- or
-      joined- table updates - this must be added **manually**, even
-      for single table inheritance.
+    * The WHERE criteria needed in order to limit the polymorphic identity to
+      specific subclasses for single-table-inheritance mappings **is included
+      automatically** .   This only applies to a subclass mapper that has no
+      table of its own.
+
+      .. versionchanged:: 1.4  ORM updates/deletes now automatically
+         accommodate for the WHERE criteria added for single-inheritance
+         mappings.
+
+    * The :func:`_orm.with_loader_criteria` option **is supported** by ORM
+      update and delete operations; criteria here will be added to that of the
+      UPDATE or DELETE statement being emitted, as well as taken into account
+      during the "synchronize" process.
 
     * In order to intercept bulk UPDATE and DELETE operations with event
       handlers, use the :meth:`_orm.SessionEvents.do_orm_execute` event.
index 068c8507322205387bc78ce9b36fba47688fe707..b1ff1a0497bd195653beb902f0fb995fe3a28784 100644 (file)
@@ -729,6 +729,8 @@ class ORMOption(ExecutableOption):
 
     _is_compile_state = False
 
+    _is_criteria_option = False
+
 
 class LoaderOption(ORMOption):
     """Describe a loader modification to an ORM statement at compilation time.
@@ -743,6 +745,27 @@ class LoaderOption(ORMOption):
         """Apply a modification to a given :class:`.CompileState`."""
 
 
+class CriteriaOption(ORMOption):
+    """Describe a WHERE criteria modification to an ORM statement at
+    compilation time.
+
+    .. versionadded:: 1.4
+
+    """
+
+    _is_compile_state = True
+    _is_criteria_option = True
+
+    def process_compile_state(self, compile_state):
+        """Apply a modification to a given :class:`.CompileState`."""
+
+    def get_global_criteria(self, attributes):
+        """update additional entity criteria options in the given
+        attributes dictionary.
+
+        """
+
+
 class UserDefinedOption(ORMOption):
     """Base class for a user-defined option that can be consumed from the
     :meth:`.SessionEvents.do_orm_execute` event hook.
index 49b29a6bc0fa1cb38a899a3f80414f689ba8329a..d05381c1d25fe9bd8867da27b5a3aaad94fa3e06 100644 (file)
@@ -1856,6 +1856,43 @@ class BulkUDCompileState(CompileState):
 
         return result
 
+    @classmethod
+    def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
+        """Apply extra criteria filtering.
+
+        For all distinct single-table-inheritance mappers represented in the
+        table being updated or deleted, produce additional WHERE criteria such
+        that only the appropriate subtypes are selected from the total results.
+
+        Additionally, add WHERE criteria originating from LoaderCriteriaOptions
+        collected from the statement.
+
+        """
+
+        return_crit = ()
+
+        adapter = ext_info._adapter if ext_info.is_aliased_class else None
+
+        if (
+            "additional_entity_criteria",
+            ext_info.mapper,
+        ) in global_attributes:
+            return_crit += tuple(
+                ae._resolve_where_criteria(ext_info)
+                for ae in global_attributes[
+                    ("additional_entity_criteria", ext_info.mapper)
+                ]
+                if ae.include_aliases or ae.entity is ext_info
+            )
+
+        if ext_info.mapper._single_table_criterion is not None:
+            return_crit += (ext_info.mapper._single_table_criterion,)
+
+        if adapter:
+            return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
+
+        return return_crit
+
     @classmethod
     def _do_pre_synchronize_evaluate(
         cls,
@@ -1873,10 +1910,22 @@ class BulkUDCompileState(CompileState):
 
         try:
             evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+            crit = ()
             if statement._where_criteria:
-                eval_condition = evaluator_compiler.process(
-                    *statement._where_criteria
+                crit += statement._where_criteria
+
+            global_attributes = {}
+            for opt in statement._with_options:
+                if opt._is_criteria_option:
+                    opt.get_global_criteria(global_attributes)
+
+            if global_attributes:
+                crit += cls._adjust_for_extra_criteria(
+                    global_attributes, mapper
                 )
+
+            if crit:
+                eval_condition = evaluator_compiler.process(*crit)
             else:
 
                 def eval_condition(obj):
@@ -1920,16 +1969,17 @@ class BulkUDCompileState(CompileState):
 
         # TODO: detect when the where clause is a trivial primary key match.
         matched_objects = [
-            obj
-            for (cls, pk, identity_token,), obj in session.identity_map.items()
-            if issubclass(cls, target_cls)
-            and eval_condition(obj)
+            state.obj()
+            for state in session.identity_map.all_states()
+            if state.mapper.isa(mapper)
+            and eval_condition(state.obj())
             and (
                 update_options._refresh_identity_token is None
                 # TODO: coverage for the case where horiziontal sharding
                 # invokes an update() or delete() given an explicit identity
                 # token up front
-                or identity_token == update_options._refresh_identity_token
+                or state.identity_token
+                == update_options._refresh_identity_token
             )
         ]
         return update_options + {
@@ -2003,8 +2053,10 @@ class BulkUDCompileState(CompileState):
     ):
         mapper = update_options._subject_mapper
 
-        select_stmt = select(
-            *(mapper.primary_key + (mapper.select_identity_token,))
+        select_stmt = (
+            select(*(mapper.primary_key + (mapper.select_identity_token,)))
+            .select_from(mapper)
+            .options(*statement._with_options)
         )
         select_stmt._where_criteria = statement._where_criteria
 
@@ -2075,12 +2127,20 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
 
         self = cls.__new__(cls)
 
-        self.mapper = mapper = statement.table._annotations.get(
-            "parentmapper", None
-        )
+        ext_info = statement.table._annotations["parententity"]
+
+        self.mapper = mapper = ext_info.mapper
+
+        self.extra_criteria_entities = {}
 
         self._resolved_values = cls._get_resolved_values(mapper, statement)
 
+        extra_criteria_attributes = {}
+
+        for opt in statement._with_options:
+            if opt._is_criteria_option:
+                opt.get_global_criteria(extra_criteria_attributes)
+
         if not statement._preserve_parameter_order and statement._values:
             self._resolved_values = dict(self._resolved_values)
 
@@ -2097,6 +2157,12 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
         elif statement._values:
             new_stmt._values = self._resolved_values
 
+        new_crit = cls._adjust_for_extra_criteria(
+            extra_criteria_attributes, mapper
+        )
+        if new_crit:
+            new_stmt = new_stmt.where(*new_crit)
+
         # if we are against a lambda statement we might not be the
         # topmost object that received per-execute annotations
         top_level_stmt = compiler.statement
@@ -2211,11 +2277,25 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
     def create_for_statement(cls, statement, compiler, **kw):
         self = cls.__new__(cls)
 
-        self.mapper = mapper = statement.table._annotations.get(
-            "parentmapper", None
-        )
+        ext_info = statement.table._annotations["parententity"]
+        self.mapper = mapper = ext_info.mapper
 
         top_level_stmt = compiler.statement
+
+        self.extra_criteria_entities = {}
+
+        extra_criteria_attributes = {}
+
+        for opt in statement._with_options:
+            if opt._is_criteria_option:
+                opt.get_global_criteria(extra_criteria_attributes)
+
+        new_crit = cls._adjust_for_extra_criteria(
+            extra_criteria_attributes, mapper
+        )
+        if new_crit:
+            statement = statement.where(*new_crit)
+
         if (
             mapper
             and top_level_stmt._annotations.get("synchronize_session", None)
index 82fad0815dbcac5c2dbd0ee8727429e51d7aaede..271a441f00b43ad77f41e12f05e5e6d4ebdfe472 100644 (file)
@@ -23,7 +23,7 @@ from .base import object_state  # noqa
 from .base import state_attribute_str  # noqa
 from .base import state_class_str  # noqa
 from .base import state_str  # noqa
-from .interfaces import LoaderOption
+from .interfaces import CriteriaOption
 from .interfaces import MapperProperty  # noqa
 from .interfaces import ORMColumnsClauseRole
 from .interfaces import ORMEntityColumnsClauseRole
@@ -856,7 +856,7 @@ class AliasedInsp(
             return "aliased(%s)" % (self._target.__name__,)
 
 
-class LoaderCriteriaOption(LoaderOption):
+class LoaderCriteriaOption(CriteriaOption):
     """Add additional WHERE criteria to the load for all occurrences of
     a particular entity.
 
@@ -1026,8 +1026,11 @@ class LoaderCriteriaOption(LoaderOption):
         # if options to limit the criteria to immediate query only,
         # use compile_state.attributes instead
 
+        self.get_global_criteria(compile_state.global_attributes)
+
+    def get_global_criteria(self, attributes):
         for mp in self._all_mappers():
-            load_criteria = compile_state.global_attributes.setdefault(
+            load_criteria = attributes.setdefault(
                 ("additional_entity_criteria", mp), []
             )
 
index a9bccaeff88cfac3efd621fef8d0ae521ec6225e..b7151ac7b04ada245f03aa7b0dacf8153edefa7e 100644 (file)
@@ -12,6 +12,7 @@ Provide :class:`_expression.Insert`, :class:`_expression.Update` and
 from sqlalchemy.types import NullType
 from . import coercions
 from . import roles
+from .base import _entity_namespace_key
 from .base import _from_objects
 from .base import _generative
 from .base import ColumnCollection
@@ -983,10 +984,30 @@ class DMLWhereBase(object):
         )
 
     def filter(self, *criteria):
-        """A synonym for the :meth:`_dml.DMLWhereBase.where` method."""
+        """A synonym for the :meth:`_dml.DMLWhereBase.where` method.
+
+        .. versionadded:: 1.4
+
+        """
 
         return self.where(*criteria)
 
+    def _filter_by_zero(self):
+        return self.table
+
+    def filter_by(self, **kwargs):
+        r"""apply the given filtering criterion as a WHERE clause
+        to this select.
+
+        """
+        from_entity = self._filter_by_zero()
+
+        clauses = [
+            _entity_namespace_key(from_entity, key) == value
+            for key, value in kwargs.items()
+        ]
+        return self.filter(*clauses)
+
     @property
     def whereclause(self):
         """Return the completed WHERE clause for this :class:`.DMLWhereBase`
index 5360aecb70e07b15cf52b1ef66ab7ae33455a078..aec5d05534369a8a57c1fd289139560a47a305b8 100644 (file)
@@ -19,7 +19,9 @@ from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker
 from sqlalchemy.orm import synonym
+from sqlalchemy.orm import with_loader_criteria
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
@@ -540,6 +542,68 @@ class UpdateDeleteTest(fixtures.MappedTest):
             list(zip([15, 27, 19, 27])),
         )
 
+    @testing.combinations(
+        ("fetch", False),
+        ("fetch", True),
+        ("evaluate", False),
+        ("evaluate", True),
+    )
+    def test_update_with_loader_criteria(self, fetchstyle, future):
+        User = self.classes.User
+
+        sess = Session(testing.db, future=True)
+
+        john, jack, jill, jane = (
+            sess.execute(select(User).order_by(User.id)).scalars().all()
+        )
+
+        sess.execute(
+            update(User)
+            .options(
+                with_loader_criteria(User, User.name.in_(["jill", "jane"]))
+            )
+            .where(User.age > 29)
+            .values(age=User.age - 10)
+            .execution_options(synchronize_session=fetchstyle)
+        )
+
+        eq_([john.age, jack.age, jill.age, jane.age], [25, 47, 29, 27])
+        eq_(
+            sess.execute(select(User.age).order_by(User.id)).all(),
+            list(zip([25, 47, 29, 27])),
+        )
+
+    @testing.combinations(
+        ("fetch", False),
+        ("fetch", True),
+        ("evaluate", False),
+        ("evaluate", True),
+    )
+    def test_delete_with_loader_criteria(self, fetchstyle, future):
+        User = self.classes.User
+
+        sess = Session(testing.db, future=True)
+
+        john, jack, jill, jane = (
+            sess.execute(select(User).order_by(User.id)).scalars().all()
+        )
+
+        sess.execute(
+            delete(User)
+            .options(
+                with_loader_criteria(User, User.name.in_(["jill", "jane"]))
+            )
+            .where(User.age > 29)
+            .execution_options(synchronize_session=fetchstyle)
+        )
+
+        assert jane not in sess
+        assert jack in sess
+        eq_(
+            sess.execute(select(User.age).order_by(User.id)).all(),
+            list(zip([25, 47, 29])),
+        )
+
     def test_update_against_table_col(self):
         User, users = self.classes.User, self.tables.users
 
@@ -1646,3 +1710,136 @@ class InheritTest(fixtures.DeclarativeMappedTest):
             set(s.query(Person.name, Engineer.engineer_name)),
             set([("e1", "e1"), ("e22", "e55")]),
         )
+
+
+class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest):
+    __backend__ = True
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Staff(Base):
+            __tablename__ = "staff"
+            position = Column(String(10), nullable=False)
+            id = Column(
+                Integer, primary_key=True, test_needs_autoincrement=True
+            )
+            name = Column(String(5))
+            stats = Column(String(5))
+            __mapper_args__ = {"polymorphic_on": position}
+
+        class Sales(Staff):
+            sales_stats = Column(String(5))
+            __mapper_args__ = {"polymorphic_identity": "sales"}
+
+        class Support(Staff):
+            support_stats = Column(String(5))
+            __mapper_args__ = {"polymorphic_identity": "support"}
+
+    @classmethod
+    def insert_data(cls, connection):
+        with sessionmaker(connection).begin() as session:
+            Sales, Support = (
+                cls.classes.Sales,
+                cls.classes.Support,
+            )
+            session.add_all(
+                [
+                    Sales(name="n1", sales_stats="1", stats="a"),
+                    Sales(name="n2", sales_stats="2", stats="b"),
+                    Support(name="n1", support_stats="3", stats="c"),
+                    Support(name="n2", support_stats="4", stats="d"),
+                ]
+            )
+
+    @testing.combinations(
+        ("fetch", False),
+        ("fetch", True),
+        ("evaluate", False),
+        ("evaluate", True),
+    )
+    def test_update(self, fetchstyle, future):
+        Staff, Sales, Support = self.classes("Staff", "Sales", "Support")
+
+        sess = Session()
+
+        en1, en2 = (
+            sess.execute(select(Sales).order_by(Sales.sales_stats))
+            .scalars()
+            .all()
+        )
+        mn1, mn2 = (
+            sess.execute(select(Support).order_by(Support.support_stats))
+            .scalars()
+            .all()
+        )
+
+        if future:
+            sess.execute(
+                update(Sales)
+                .filter_by(name="n1")
+                .values(stats="p")
+                .execution_options(synchronize_session=fetchstyle)
+            )
+        else:
+            sess.query(Sales).filter_by(name="n1").update(
+                {"stats": "p"}, synchronize_session=fetchstyle
+            )
+
+        eq_(en1.stats, "p")
+        eq_(mn1.stats, "c")
+        eq_(
+            sess.execute(
+                select(Staff.position, Staff.name, Staff.stats).order_by(
+                    Staff.id
+                )
+            ).all(),
+            [
+                ("sales", "n1", "p"),
+                ("sales", "n2", "b"),
+                ("support", "n1", "c"),
+                ("support", "n2", "d"),
+            ],
+        )
+
+    @testing.combinations(
+        ("fetch", False),
+        ("fetch", True),
+        ("evaluate", False),
+        ("evaluate", True),
+    )
+    def test_delete(self, fetchstyle, future):
+        Staff, Sales, Support = self.classes("Staff", "Sales", "Support")
+
+        sess = Session()
+        en1, en2 = sess.query(Sales).order_by(Sales.sales_stats).all()
+        mn1, mn2 = sess.query(Support).order_by(Support.support_stats).all()
+
+        if future:
+            sess.execute(
+                delete(Sales)
+                .filter_by(name="n1")
+                .execution_options(synchronize_session=fetchstyle)
+            )
+        else:
+            sess.query(Sales).filter_by(name="n1").delete(
+                synchronize_session=fetchstyle
+            )
+        assert en1 not in sess
+        assert en2 in sess
+        assert mn1 in sess
+        assert mn2 in sess
+
+        eq_(
+            sess.execute(
+                select(Staff.position, Staff.name, Staff.stats).order_by(
+                    Staff.id
+                )
+            ).all(),
+            [
+                ("sales", "n2", "b"),
+                ("support", "n1", "c"),
+                ("support", "n2", "d"),
+            ],
+        )