--- /dev/null
+.. change::
+ :tags: usecase, orm
+ :tickets: 5018, 3903
+
+ The ORM bulk update and delete operations, historically available via the
+ :meth:`_orm.Query.update` and :meth:`_orm.Query.delete` methods as well as
+ via the :class:`_dml.Update` and :class:`_dml.Delete` constructs for
+ :term:`2.0 style` execution, will now automatically accommodate for the
+ additional WHERE criteria needed for a single-table inheritance
+ discriminator in order to limit the statement to rows referring to the
+ specific subtype requested. The new :func:`_orm.with_loader_criteria`
+ construct is also supported for with bulk update/delete operations.
Similar guidelines as those detailed at :ref:`multi_table_updates`
may be applied.
- * The polymorphic identity WHERE criteria is **not** included
- for single- or
- joined- table updates - this must be added **manually**, even
- for single table inheritance.
+ * The WHERE criteria needed in order to limit the polymorphic identity to
+ specific subclasses for single-table-inheritance mappings **is included
+ automatically** . This only applies to a subclass mapper that has no
+ table of its own.
+
+ .. versionchanged:: 1.4 ORM updates/deletes now automatically
+ accommodate for the WHERE criteria added for single-inheritance
+ mappings.
+
+ * The :func:`_orm.with_loader_criteria` option **is supported** by ORM
+ update and delete operations; criteria here will be added to that of the
+ UPDATE or DELETE statement being emitted, as well as taken into account
+ during the "synchronize" process.
* In order to intercept bulk UPDATE and DELETE operations with event
handlers, use the :meth:`_orm.SessionEvents.do_orm_execute` event.
_is_compile_state = False
+ _is_criteria_option = False
+
class LoaderOption(ORMOption):
"""Describe a loader modification to an ORM statement at compilation time.
"""Apply a modification to a given :class:`.CompileState`."""
+class CriteriaOption(ORMOption):
+ """Describe a WHERE criteria modification to an ORM statement at
+ compilation time.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _is_compile_state = True
+ _is_criteria_option = True
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
+ def get_global_criteria(self, attributes):
+ """update additional entity criteria options in the given
+ attributes dictionary.
+
+ """
+
+
class UserDefinedOption(ORMOption):
"""Base class for a user-defined option that can be consumed from the
:meth:`.SessionEvents.do_orm_execute` event hook.
return result
+ @classmethod
+ def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
+ """Apply extra criteria filtering.
+
+ For all distinct single-table-inheritance mappers represented in the
+ table being updated or deleted, produce additional WHERE criteria such
+ that only the appropriate subtypes are selected from the total results.
+
+ Additionally, add WHERE criteria originating from LoaderCriteriaOptions
+ collected from the statement.
+
+ """
+
+ return_crit = ()
+
+ adapter = ext_info._adapter if ext_info.is_aliased_class else None
+
+ if (
+ "additional_entity_criteria",
+ ext_info.mapper,
+ ) in global_attributes:
+ return_crit += tuple(
+ ae._resolve_where_criteria(ext_info)
+ for ae in global_attributes[
+ ("additional_entity_criteria", ext_info.mapper)
+ ]
+ if ae.include_aliases or ae.entity is ext_info
+ )
+
+ if ext_info.mapper._single_table_criterion is not None:
+ return_crit += (ext_info.mapper._single_table_criterion,)
+
+ if adapter:
+ return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
+
+ return return_crit
+
@classmethod
def _do_pre_synchronize_evaluate(
cls,
try:
evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ crit = ()
if statement._where_criteria:
- eval_condition = evaluator_compiler.process(
- *statement._where_criteria
+ crit += statement._where_criteria
+
+ global_attributes = {}
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(global_attributes)
+
+ if global_attributes:
+ crit += cls._adjust_for_extra_criteria(
+ global_attributes, mapper
)
+
+ if crit:
+ eval_condition = evaluator_compiler.process(*crit)
else:
def eval_condition(obj):
# TODO: detect when the where clause is a trivial primary key match.
matched_objects = [
- obj
- for (cls, pk, identity_token,), obj in session.identity_map.items()
- if issubclass(cls, target_cls)
- and eval_condition(obj)
+ state.obj()
+ for state in session.identity_map.all_states()
+ if state.mapper.isa(mapper)
+ and eval_condition(state.obj())
and (
update_options._refresh_identity_token is None
# TODO: coverage for the case where horiziontal sharding
# invokes an update() or delete() given an explicit identity
# token up front
- or identity_token == update_options._refresh_identity_token
+ or state.identity_token
+ == update_options._refresh_identity_token
)
]
return update_options + {
):
mapper = update_options._subject_mapper
- select_stmt = select(
- *(mapper.primary_key + (mapper.select_identity_token,))
+ select_stmt = (
+ select(*(mapper.primary_key + (mapper.select_identity_token,)))
+ .select_from(mapper)
+ .options(*statement._with_options)
)
select_stmt._where_criteria = statement._where_criteria
self = cls.__new__(cls)
- self.mapper = mapper = statement.table._annotations.get(
- "parentmapper", None
- )
+ ext_info = statement.table._annotations["parententity"]
+
+ self.mapper = mapper = ext_info.mapper
+
+ self.extra_criteria_entities = {}
self._resolved_values = cls._get_resolved_values(mapper, statement)
+ extra_criteria_attributes = {}
+
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(extra_criteria_attributes)
+
if not statement._preserve_parameter_order and statement._values:
self._resolved_values = dict(self._resolved_values)
elif statement._values:
new_stmt._values = self._resolved_values
+ new_crit = cls._adjust_for_extra_criteria(
+ extra_criteria_attributes, mapper
+ )
+ if new_crit:
+ new_stmt = new_stmt.where(*new_crit)
+
# if we are against a lambda statement we might not be the
# topmost object that received per-execute annotations
top_level_stmt = compiler.statement
def create_for_statement(cls, statement, compiler, **kw):
self = cls.__new__(cls)
- self.mapper = mapper = statement.table._annotations.get(
- "parentmapper", None
- )
+ ext_info = statement.table._annotations["parententity"]
+ self.mapper = mapper = ext_info.mapper
top_level_stmt = compiler.statement
+
+ self.extra_criteria_entities = {}
+
+ extra_criteria_attributes = {}
+
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(extra_criteria_attributes)
+
+ new_crit = cls._adjust_for_extra_criteria(
+ extra_criteria_attributes, mapper
+ )
+ if new_crit:
+ statement = statement.where(*new_crit)
+
if (
mapper
and top_level_stmt._annotations.get("synchronize_session", None)
from .base import state_attribute_str # noqa
from .base import state_class_str # noqa
from .base import state_str # noqa
-from .interfaces import LoaderOption
+from .interfaces import CriteriaOption
from .interfaces import MapperProperty # noqa
from .interfaces import ORMColumnsClauseRole
from .interfaces import ORMEntityColumnsClauseRole
return "aliased(%s)" % (self._target.__name__,)
-class LoaderCriteriaOption(LoaderOption):
+class LoaderCriteriaOption(CriteriaOption):
"""Add additional WHERE criteria to the load for all occurrences of
a particular entity.
# if options to limit the criteria to immediate query only,
# use compile_state.attributes instead
+ self.get_global_criteria(compile_state.global_attributes)
+
+ def get_global_criteria(self, attributes):
for mp in self._all_mappers():
- load_criteria = compile_state.global_attributes.setdefault(
+ load_criteria = attributes.setdefault(
("additional_entity_criteria", mp), []
)
from sqlalchemy.types import NullType
from . import coercions
from . import roles
+from .base import _entity_namespace_key
from .base import _from_objects
from .base import _generative
from .base import ColumnCollection
)
def filter(self, *criteria):
- """A synonym for the :meth:`_dml.DMLWhereBase.where` method."""
+ """A synonym for the :meth:`_dml.DMLWhereBase.where` method.
+
+ .. versionadded:: 1.4
+
+ """
return self.where(*criteria)
+ def _filter_by_zero(self):
+ return self.table
+
+ def filter_by(self, **kwargs):
+ r"""apply the given filtering criterion as a WHERE clause
+ to this select.
+
+ """
+ from_entity = self._filter_by_zero()
+
+ clauses = [
+ _entity_namespace_key(from_entity, key) == value
+ for key, value in kwargs.items()
+ ]
+ return self.filter(*clauses)
+
@property
def whereclause(self):
"""Return the completed WHERE clause for this :class:`.DMLWhereBase`
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import synonym
+from sqlalchemy.orm import with_loader_criteria
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
list(zip([15, 27, 19, 27])),
)
+ @testing.combinations(
+ ("fetch", False),
+ ("fetch", True),
+ ("evaluate", False),
+ ("evaluate", True),
+ )
+ def test_update_with_loader_criteria(self, fetchstyle, future):
+ User = self.classes.User
+
+ sess = Session(testing.db, future=True)
+
+ john, jack, jill, jane = (
+ sess.execute(select(User).order_by(User.id)).scalars().all()
+ )
+
+ sess.execute(
+ update(User)
+ .options(
+ with_loader_criteria(User, User.name.in_(["jill", "jane"]))
+ )
+ .where(User.age > 29)
+ .values(age=User.age - 10)
+ .execution_options(synchronize_session=fetchstyle)
+ )
+
+ eq_([john.age, jack.age, jill.age, jane.age], [25, 47, 29, 27])
+ eq_(
+ sess.execute(select(User.age).order_by(User.id)).all(),
+ list(zip([25, 47, 29, 27])),
+ )
+
+ @testing.combinations(
+ ("fetch", False),
+ ("fetch", True),
+ ("evaluate", False),
+ ("evaluate", True),
+ )
+ def test_delete_with_loader_criteria(self, fetchstyle, future):
+ User = self.classes.User
+
+ sess = Session(testing.db, future=True)
+
+ john, jack, jill, jane = (
+ sess.execute(select(User).order_by(User.id)).scalars().all()
+ )
+
+ sess.execute(
+ delete(User)
+ .options(
+ with_loader_criteria(User, User.name.in_(["jill", "jane"]))
+ )
+ .where(User.age > 29)
+ .execution_options(synchronize_session=fetchstyle)
+ )
+
+ assert jane not in sess
+ assert jack in sess
+ eq_(
+ sess.execute(select(User.age).order_by(User.id)).all(),
+ list(zip([25, 47, 29])),
+ )
+
def test_update_against_table_col(self):
User, users = self.classes.User, self.tables.users
set(s.query(Person.name, Engineer.engineer_name)),
set([("e1", "e1"), ("e22", "e55")]),
)
+
+
+class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest):
+ __backend__ = True
+
+ @classmethod
+ def setup_classes(cls):
+ Base = cls.DeclarativeBasic
+
+ class Staff(Base):
+ __tablename__ = "staff"
+ position = Column(String(10), nullable=False)
+ id = Column(
+ Integer, primary_key=True, test_needs_autoincrement=True
+ )
+ name = Column(String(5))
+ stats = Column(String(5))
+ __mapper_args__ = {"polymorphic_on": position}
+
+ class Sales(Staff):
+ sales_stats = Column(String(5))
+ __mapper_args__ = {"polymorphic_identity": "sales"}
+
+ class Support(Staff):
+ support_stats = Column(String(5))
+ __mapper_args__ = {"polymorphic_identity": "support"}
+
+ @classmethod
+ def insert_data(cls, connection):
+ with sessionmaker(connection).begin() as session:
+ Sales, Support = (
+ cls.classes.Sales,
+ cls.classes.Support,
+ )
+ session.add_all(
+ [
+ Sales(name="n1", sales_stats="1", stats="a"),
+ Sales(name="n2", sales_stats="2", stats="b"),
+ Support(name="n1", support_stats="3", stats="c"),
+ Support(name="n2", support_stats="4", stats="d"),
+ ]
+ )
+
+ @testing.combinations(
+ ("fetch", False),
+ ("fetch", True),
+ ("evaluate", False),
+ ("evaluate", True),
+ )
+ def test_update(self, fetchstyle, future):
+ Staff, Sales, Support = self.classes("Staff", "Sales", "Support")
+
+ sess = Session()
+
+ en1, en2 = (
+ sess.execute(select(Sales).order_by(Sales.sales_stats))
+ .scalars()
+ .all()
+ )
+ mn1, mn2 = (
+ sess.execute(select(Support).order_by(Support.support_stats))
+ .scalars()
+ .all()
+ )
+
+ if future:
+ sess.execute(
+ update(Sales)
+ .filter_by(name="n1")
+ .values(stats="p")
+ .execution_options(synchronize_session=fetchstyle)
+ )
+ else:
+ sess.query(Sales).filter_by(name="n1").update(
+ {"stats": "p"}, synchronize_session=fetchstyle
+ )
+
+ eq_(en1.stats, "p")
+ eq_(mn1.stats, "c")
+ eq_(
+ sess.execute(
+ select(Staff.position, Staff.name, Staff.stats).order_by(
+ Staff.id
+ )
+ ).all(),
+ [
+ ("sales", "n1", "p"),
+ ("sales", "n2", "b"),
+ ("support", "n1", "c"),
+ ("support", "n2", "d"),
+ ],
+ )
+
+ @testing.combinations(
+ ("fetch", False),
+ ("fetch", True),
+ ("evaluate", False),
+ ("evaluate", True),
+ )
+ def test_delete(self, fetchstyle, future):
+ Staff, Sales, Support = self.classes("Staff", "Sales", "Support")
+
+ sess = Session()
+ en1, en2 = sess.query(Sales).order_by(Sales.sales_stats).all()
+ mn1, mn2 = sess.query(Support).order_by(Support.support_stats).all()
+
+ if future:
+ sess.execute(
+ delete(Sales)
+ .filter_by(name="n1")
+ .execution_options(synchronize_session=fetchstyle)
+ )
+ else:
+ sess.query(Sales).filter_by(name="n1").delete(
+ synchronize_session=fetchstyle
+ )
+ assert en1 not in sess
+ assert en2 in sess
+ assert mn1 in sess
+ assert mn2 in sess
+
+ eq_(
+ sess.execute(
+ select(Staff.position, Staff.name, Staff.stats).order_by(
+ Staff.id
+ )
+ ).all(),
+ [
+ ("sales", "n2", "b"),
+ ("support", "n1", "c"),
+ ("support", "n2", "d"),
+ ],
+ )