]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
process bulk_update_tuples before cache key or compilation
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Oct 2021 18:07:32 +0000 (14:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Oct 2021 19:10:14 +0000 (15:10 -0400)
Fixed regression where the use of a :class:`_orm.hybrid_property` attribute
or a mapped :func:`_orm.composite` attribute as a key passed to the
:meth:`_dml.Update.values` method for an ORM-enabled :class:`_dml.Update`
statement, as well as when using it via the legacy
:meth:`_orm.Query.update` method, would be processed for incoming
ORM/hybrid/composite values within the compilation stage of the UPDATE
statement, which meant that in those cases where caching occurred,
subsequent invocations of the same statement would no longer receive the
correct values. This would include not only hybrids that use the
:meth:`_orm.hybrid_property.update_expression` method, but any use of a
plain hybrid attribute as well. For composites, the issue instead caused a
non-repeatable cache key to be generated, which would break caching and
could fill up the statement cache with repeated statements.

The :class:`_dml.Update` construct now handles the processing of key/value
pairs passed to :meth:`_dml.Update.values` and
:meth:`_dml.Update.ordered_values` up front when the construct is first
generated, before the cache key has been generated so that the key/value
pairs are processed each time, and so that the cache key is generated
against the individual column/value pairs that will ultimately be
used in the statement.

Fixes: #7209
Change-Id: I08f248d1d60ea9690b014c21439b775d951fb9e5

12 files changed:
doc/build/changelog/unreleased_14/7209.rst [new file with mode: 0644]
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/elements.py
test/ext/test_hybrid.py
test/orm/test_cache_key.py
test/orm/test_composites.py
test/orm/test_update_delete.py
test/requirements.py
test/sql/test_update.py

diff --git a/doc/build/changelog/unreleased_14/7209.rst b/doc/build/changelog/unreleased_14/7209.rst
new file mode 100644 (file)
index 0000000..9ae00e7
--- /dev/null
@@ -0,0 +1,26 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 7209
+
+    Fixed regression where the use of a :class:`_orm.hybrid_property` attribute
+    or a mapped :func:`_orm.composite` attribute as a key passed to the
+    :meth:`_dml.Update.values` method for an ORM-enabled :class:`_dml.Update`
+    statement, as well as when using it via the legacy
+    :meth:`_orm.Query.update` method, would be processed for incoming
+    ORM/hybrid/composite values within the compilation stage of the UPDATE
+    statement, which meant that in those cases where caching occurred,
+    subsequent invocations of the same statement would no longer receive the
+    correct values. This would include not only hybrids that use the
+    :meth:`_orm.hybrid_property.update_expression` method, but any use of a
+    plain hybrid attribute as well. For composites, the issue instead caused a
+    non-repeatable cache key to be generated, which would break caching and
+    could fill up the statement cache with repeated statements.
+
+    The :class:`_dml.Update` construct now handles the processing of key/value
+    pairs passed to :meth:`_dml.Update.values` and
+    :meth:`_dml.Update.ordered_values` up front when the construct is first
+    generated, before the cache key has been generated so that the key/value
+    pairs are processed each time, and so that the cache key is generated
+    against the individual column/value pairs that will ultimately be
+    used in the statement.
+
index 298d957f69fd8d5757de3b23fae62f02151d5632..eab3f2b7385707ae903217bfafb74f3fec5057de 100644 (file)
@@ -805,7 +805,6 @@ things it can be used for.
 from .. import util
 from ..orm import attributes
 from ..orm import interfaces
-from ..sql import elements
 
 HYBRID_METHOD = util.symbol("HYBRID_METHOD")
 """Symbol indicating an :class:`InspectionAttr` that's
@@ -1183,9 +1182,6 @@ class ExprComparator(Comparator):
         return self.hybrid.info
 
     def _bulk_update_tuples(self, value):
-        if isinstance(value, elements.BindParameter):
-            value = value.value
-
         if isinstance(self.expression, attributes.QueryableAttribute):
             return self.expression._bulk_update_tuples(value)
         elif self.hybrid.update_expr is not None:
index fd484b52b30df3d690df97aefb0f5480bba94037..3d20cfdea076986edddca885ac9d93cde41ec965 100644 (file)
@@ -23,6 +23,7 @@ from . import evaluator
 from . import exc as orm_exc
 from . import loading
 from . import sync
+from .base import NO_VALUE
 from .base import state_str
 from .. import exc as sa_exc
 from .. import future
@@ -34,6 +35,7 @@ from ..sql import expression
 from ..sql import operators
 from ..sql import roles
 from ..sql import select
+from ..sql import sqltypes
 from ..sql.base import _entity_namespace_key
 from ..sql.base import CompileState
 from ..sql.base import Options
@@ -2002,31 +2004,12 @@ class BulkUDCompileState(CompileState):
         if statement._multi_values:
             return []
         elif statement._ordered_values:
-            iterator = statement._ordered_values
+            return list(statement._ordered_values)
         elif statement._values:
-            iterator = statement._values.items()
+            return list(statement._values.items())
         else:
             return []
 
-        values = []
-        if iterator:
-            for k, v in iterator:
-                if mapper:
-                    if isinstance(k, util.string_types):
-                        desc = _entity_namespace_key(mapper, k)
-                        values.extend(desc._bulk_update_tuples(v))
-                    elif "entity_namespace" in k._annotations:
-                        k_anno = k._annotations
-                        attr = _entity_namespace_key(
-                            k_anno["entity_namespace"], k_anno["proxy_key"]
-                        )
-                        values.extend(attr._bulk_update_tuples(v))
-                    else:
-                        values.append((k, v))
-                else:
-                    values.append((k, v))
-        return values
-
     @classmethod
     def _resolved_keys_as_propnames(cls, mapper, resolved_values):
         values = []
@@ -2190,6 +2173,68 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
 
         return self
 
+    @classmethod
+    def _get_crud_kv_pairs(cls, statement, kv_iterator):
+        plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+        if plugin_subject:
+            mapper = plugin_subject.mapper
+        else:
+            mapper = None
+
+        values = []
+        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
+
+        for k, v in kv_iterator:
+            if mapper:
+                k = coercions.expect(roles.DMLColumnRole, k)
+
+                if isinstance(k, util.string_types):
+                    desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
+                    if desc is NO_VALUE:
+                        values.append(
+                            (
+                                k,
+                                coercions.expect(
+                                    roles.ExpressionElementRole,
+                                    v,
+                                    type_=sqltypes.NullType(),
+                                    is_crud=True,
+                                ),
+                            )
+                        )
+                    else:
+                        values.extend(
+                            core_get_crud_kv_pairs(
+                                statement, desc._bulk_update_tuples(v)
+                            )
+                        )
+                elif "entity_namespace" in k._annotations:
+                    k_anno = k._annotations
+                    attr = _entity_namespace_key(
+                        k_anno["entity_namespace"], k_anno["proxy_key"]
+                    )
+                    values.extend(
+                        core_get_crud_kv_pairs(
+                            statement, attr._bulk_update_tuples(v)
+                        )
+                    )
+                else:
+                    values.append(
+                        (
+                            k,
+                            coercions.expect(
+                                roles.ExpressionElementRole,
+                                v,
+                                type_=sqltypes.NullType(),
+                                is_crud=True,
+                            ),
+                        )
+                    )
+            else:
+                values.extend(core_get_crud_kv_pairs(statement, [(k, v)]))
+        return values
+
     @classmethod
     def _do_post_synchronize_evaluate(cls, session, result, update_options):
 
index b235f5132f0dfa961ccd2ce39e7e1972ca59d73f..aba80222a6b8bbd379360da6a78a77abb96bff05 100644 (file)
@@ -515,12 +515,20 @@ class CompileState(object):
     @classmethod
     def get_plugin_class(cls, statement):
         plugin_name = statement._propagate_attrs.get(
-            "compile_state_plugin", "default"
+            "compile_state_plugin", None
         )
+
+        if plugin_name:
+            key = (plugin_name, statement._effective_plugin_target)
+            if key in cls.plugins:
+                return cls.plugins[key]
+
+        # there's no case where we call upon get_plugin_class() and want
+        # to get None back, there should always be a default.  return that
+        # if there was no plugin-specific class  (e.g. Insert with "orm"
+        # plugin)
         try:
-            return cls.plugins[
-                (plugin_name, statement._effective_plugin_target)
-            ]
+            return cls.plugins[("default", statement._effective_plugin_target)]
         except KeyError:
             return None
 
@@ -1665,7 +1673,7 @@ def _entity_namespace(entity):
             raise
 
 
-def _entity_namespace_key(entity, key):
+def _entity_namespace_key(entity, key, default=NO_ARG):
     """Return an entry from an entity_namespace.
 
 
@@ -1676,7 +1684,10 @@ def _entity_namespace_key(entity, key):
 
     try:
         ns = _entity_namespace(entity)
-        return getattr(ns, key)
+        if default is not NO_ARG:
+            return getattr(ns, key, default)
+        else:
+            return getattr(ns, key)
     except AttributeError as err:
         util.raise_(
             exc.InvalidRequestError(
index 158cb40f2773e5a14510803895a204255f470cfc..ebff0df88d163081d36615170c786685a909bf1b 100644 (file)
@@ -52,6 +52,21 @@ class DMLState(CompileState):
     def dml_table(self):
         return self.statement.table
 
+    @classmethod
+    def _get_crud_kv_pairs(cls, statement, kv_iterator):
+        return [
+            (
+                coercions.expect(roles.DMLColumnRole, k),
+                coercions.expect(
+                    roles.ExpressionElementRole,
+                    v,
+                    type_=NullType(),
+                    is_crud=True,
+                ),
+            )
+            for k, v in kv_iterator
+        ]
+
     def _make_extra_froms(self, statement):
         froms = []
 
@@ -674,30 +689,12 @@ class ValuesBase(UpdateBase):
         # crud.py now intercepts bound parameters with unique=True from here
         # and ensures they get the "crud"-style name when rendered.
 
+        kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
+
         if self._preserve_parameter_order:
-            arg = [
-                (
-                    coercions.expect(roles.DMLColumnRole, k),
-                    coercions.expect(
-                        roles.ExpressionElementRole,
-                        v,
-                        type_=NullType(),
-                        is_crud=True,
-                    ),
-                )
-                for k, v in arg
-            ]
-            self._ordered_values = arg
+            self._ordered_values = kv_generator(self, arg)
         else:
-            arg = {
-                coercions.expect(roles.DMLColumnRole, k): coercions.expect(
-                    roles.ExpressionElementRole,
-                    v,
-                    type_=NullType(),
-                    is_crud=True,
-                )
-                for k, v in arg.items()
-            }
+            arg = {k: v for k, v in kv_generator(self, arg.items())}
             if self._values:
                 self._values = self._values.union(arg)
             else:
@@ -1319,19 +1316,9 @@ class Update(DMLWhereBase, ValuesBase):
             raise exc.ArgumentError(
                 "This statement already has ordered values present"
             )
-        arg = [
-            (
-                coercions.expect(roles.DMLColumnRole, k),
-                coercions.expect(
-                    roles.ExpressionElementRole,
-                    v,
-                    type_=NullType(),
-                    is_crud=True,
-                ),
-            )
-            for k, v in args
-        ]
-        self._ordered_values = arg
+
+        kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
+        self._ordered_values = kv_generator(self, args)
 
     @_generative
     def inline(self):
index 3699f872bb4d9618726387030cc511277d508016..ae105428c3c8af940f92af76809cac00f57e0d40 100644 (file)
@@ -1493,7 +1493,6 @@ class BindParameter(roles.InElementRole, ColumnElement):
                 :ref:`change_4808`.
 
         """
-
         if required is NO_ARG:
             required = value is NO_ARG and callable_ is None
         if value is NO_ARG:
index f0bb87055a56b7e0aa5c1853993eb290ac694737..40cca5266295ec908ff49114625458abdbaebeff 100644 (file)
@@ -945,6 +945,34 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
             params={"first_name": "Dr.", "last_name": "No"},
         )
 
+    # these tests all run two UPDATES to assert that caching is not
+    # interfering.  this is #7209
+
+    def test_evaluate_non_hybrid_attr(self):
+        # this is a control case
+        Person = self.classes.Person
+
+        s = fixture_session()
+        jill = s.query(Person).get(3)
+
+        s.query(Person).update(
+            {Person.first_name: "moonbeam"}, synchronize_session="evaluate"
+        )
+        eq_(jill.first_name, "moonbeam")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "moonbeam",
+        )
+
+        s.query(Person).update(
+            {Person.first_name: "sunshine"}, synchronize_session="evaluate"
+        )
+        eq_(jill.first_name, "sunshine")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "sunshine",
+        )
+
     def test_evaluate_hybrid_attr_indirect(self):
         Person = self.classes.Person
 
@@ -955,6 +983,19 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
             {Person.fname2: "moonbeam"}, synchronize_session="evaluate"
         )
         eq_(jill.fname2, "moonbeam")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "moonbeam",
+        )
+
+        s.query(Person).update(
+            {Person.fname2: "sunshine"}, synchronize_session="evaluate"
+        )
+        eq_(jill.fname2, "sunshine")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "sunshine",
+        )
 
     def test_evaluate_hybrid_attr_plain(self):
         Person = self.classes.Person
@@ -966,6 +1007,19 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
             {Person.fname: "moonbeam"}, synchronize_session="evaluate"
         )
         eq_(jill.fname, "moonbeam")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "moonbeam",
+        )
+
+        s.query(Person).update(
+            {Person.fname: "sunshine"}, synchronize_session="evaluate"
+        )
+        eq_(jill.fname, "sunshine")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "sunshine",
+        )
 
     def test_fetch_hybrid_attr_indirect(self):
         Person = self.classes.Person
@@ -977,6 +1031,19 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
             {Person.fname2: "moonbeam"}, synchronize_session="fetch"
         )
         eq_(jill.fname2, "moonbeam")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "moonbeam",
+        )
+
+        s.query(Person).update(
+            {Person.fname2: "sunshine"}, synchronize_session="fetch"
+        )
+        eq_(jill.fname2, "sunshine")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "sunshine",
+        )
 
     def test_fetch_hybrid_attr_plain(self):
         Person = self.classes.Person
@@ -988,6 +1055,19 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
             {Person.fname: "moonbeam"}, synchronize_session="fetch"
         )
         eq_(jill.fname, "moonbeam")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "moonbeam",
+        )
+
+        s.query(Person).update(
+            {Person.fname: "sunshine"}, synchronize_session="fetch"
+        )
+        eq_(jill.fname, "sunshine")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "sunshine",
+        )
 
     def test_evaluate_hybrid_attr_w_update_expr(self):
         Person = self.classes.Person
@@ -999,6 +1079,16 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
             {Person.name: "moonbeam sunshine"}, synchronize_session="evaluate"
         )
         eq_(jill.name, "moonbeam sunshine")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "moonbeam",
+        )
+
+        s.query(Person).update(
+            {Person.name: "first last"}, synchronize_session="evaluate"
+        )
+        eq_(jill.name, "first last")
+        eq_(s.scalar(select(Person.first_name).where(Person.id == 3)), "first")
 
     def test_fetch_hybrid_attr_w_update_expr(self):
         Person = self.classes.Person
@@ -1010,6 +1100,16 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
             {Person.name: "moonbeam sunshine"}, synchronize_session="fetch"
         )
         eq_(jill.name, "moonbeam sunshine")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "moonbeam",
+        )
+
+        s.query(Person).update(
+            {Person.name: "first last"}, synchronize_session="fetch"
+        )
+        eq_(jill.name, "first last")
+        eq_(s.scalar(select(Person.first_name).where(Person.id == 3)), "first")
 
     def test_evaluate_hybrid_attr_indirect_w_update_expr(self):
         Person = self.classes.Person
@@ -1021,6 +1121,16 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
             {Person.uname: "moonbeam sunshine"}, synchronize_session="evaluate"
         )
         eq_(jill.uname, "moonbeam sunshine")
+        eq_(
+            s.scalar(select(Person.first_name).where(Person.id == 3)),
+            "moonbeam",
+        )
+
+        s.query(Person).update(
+            {Person.uname: "first last"}, synchronize_session="evaluate"
+        )
+        eq_(jill.uname, "first last")
+        eq_(s.scalar(select(Person.first_name).where(Person.id == 3)), "first")
 
 
 class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL):
index 3c6536195a7eb9e738247e10c87d11d69a20cf94..f25a57fe5a9b515b11e5b5b6d8af1307968104f1 100644 (file)
@@ -1,12 +1,17 @@
 import random
 
+import sqlalchemy as sa
+from sqlalchemy import Column
 from sqlalchemy import func
 from sqlalchemy import inspect
+from sqlalchemy import Integer
 from sqlalchemy import null
 from sqlalchemy import select
+from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import true
+from sqlalchemy import update
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import Bundle
 from sqlalchemy.orm import defaultload
@@ -29,6 +34,7 @@ from sqlalchemy.sql.expression import case
 from sqlalchemy.sql.visitors import InternalTraversal
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import ne_
 from sqlalchemy.testing.fixtures import fixture_session
 from test.orm import _fixtures
@@ -884,3 +890,74 @@ class RoundTripTest(QueryTest, AssertsCompiledSQL):
             go()
 
         eq_(len(cache), lc)
+
+
+class CompositeTest(fixtures.MappedTest):
+    __dialect__ = "default"
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "edges",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("x1", Integer),
+            Column("y1", Integer),
+            Column("x2", Integer),
+            Column("y2", Integer),
+        )
+
+    @classmethod
+    def setup_mappers(cls):
+        edges = cls.tables.edges
+
+        class Point(cls.Comparable):
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+
+            def __composite_values__(self):
+                return [self.x, self.y]
+
+            __hash__ = None
+
+            def __eq__(self, other):
+                return (
+                    isinstance(other, Point)
+                    and other.x == self.x
+                    and other.y == self.y
+                )
+
+            def __ne__(self, other):
+                return not isinstance(other, Point) or not self.__eq__(other)
+
+        class Edge(cls.Comparable):
+            def __init__(self, *args):
+                if args:
+                    self.start, self.end = args
+
+        cls.mapper_registry.map_imperatively(
+            Edge,
+            edges,
+            properties={
+                "start": sa.orm.composite(Point, edges.c.x1, edges.c.y1),
+                "end": sa.orm.composite(Point, edges.c.x2, edges.c.y2),
+            },
+        )
+
+    def test_bulk_update_cache_key(self):
+        """test secondary issue located as part of #7209"""
+        Edge, Point = (self.classes.Edge, self.classes.Point)
+
+        stmt = (
+            update(Edge)
+            .filter(Edge.start == Point(14, 5))
+            .values({Edge.end: Point(16, 10)})
+        )
+        stmt2 = (
+            update(Edge)
+            .filter(Edge.start == Point(14, 5))
+            .values({Edge.end: Point(17, 8)})
+        )
+
+        eq_(stmt._generate_cache_key(), stmt2._generate_cache_key())
index 2f3b9a70e4005c892fab728804d49018ae8079db..4bdca7a45ac90c4e639f6b2468f1e34d97791f71 100644 (file)
@@ -273,6 +273,15 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
 
         eq_(e1.end, Point(16, 10))
 
+        stmt = (
+            update(Edge)
+            .filter(Edge.start == Point(14, 5))
+            .values({Edge.end: Point(17, 8)})
+        )
+        sess.execute(stmt)
+
+        eq_(e1.end, Point(17, 8))
+
     def test_bulk_update_fetch(self):
         Edge, Point = (self.classes.Edge, self.classes.Point)
 
@@ -287,6 +296,10 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
 
         eq_(e1.end, Point(16, 10))
 
+        q.update({Edge.end: Point(17, 8)}, synchronize_session="fetch")
+
+        eq_(e1.end, Point(17, 8))
+
     def test_get_history(self):
         Edge = self.classes.Edge
         Point = self.classes.Point
index d6806d9bdf6a571ae38c9a9c7064cfa101e616ac..a582862861515385c34f900192280981cbc8c6ab 100644 (file)
@@ -1674,11 +1674,37 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
         s = fixture_session()
 
         q = s.query(User).filter(User.id == Document.user_id)
+
         assert_raises_message(
             exc.InvalidRequestError,
             "Could not evaluate current criteria in Python.",
             q.update,
-            {"name": "ed"},
+            {"samename": "ed"},
+        )
+
+    @testing.requires.multi_table_update
+    def test_multi_table_criteria_ok_wo_eval(self):
+        User = self.classes.User
+        Document = self.classes.Document
+
+        s = fixture_session()
+
+        q = s.query(User).filter(User.id == Document.user_id)
+
+        q.update({Document.samename: "ed"}, synchronize_session="fetch")
+        eq_(
+            s.query(User.id, Document.samename, User.samename)
+            .filter(User.id == Document.user_id)
+            .order_by(User.id)
+            .all(),
+            [
+                (1, "ed", None),
+                (1, "ed", None),
+                (2, "ed", None),
+                (2, "ed", None),
+                (3, "ed", None),
+                (3, "ed", None),
+            ],
         )
 
     @testing.requires.update_where_target_in_subquery
@@ -1744,7 +1770,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
             ),
         )
 
-    @testing.only_on("mysql", "Multi table update")
+    @testing.requires.multi_table_update
     def test_update_from_multitable_same_names(self):
         Document = self.classes.Document
         User = self.classes.User
index 6d65c2976a5f47921df2c49d5012d587b99da1e3..687dadfd1aa47e010c9d2522b93eb280ff754ef7 100644 (file)
@@ -508,6 +508,10 @@ class DefaultRequirements(SuiteRequirements):
             'outer-joined to a subquery"',
         )
 
+    @property
+    def multi_table_update(self):
+        return only_on(["mysql", "mariadb"], "Multi table update")
+
     @property
     def update_from(self):
         """Target must support UPDATE..FROM syntax"""
index 8004e6a4cfe81f9832f672e01bb152af232c2777..93deae5565edb7749fb4e7347ae1b7aed59e067e 100644 (file)
@@ -1328,7 +1328,7 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
         ]
         self._assert_addresses(connection, addresses, expected)
 
-    @testing.only_on("mysql", "Multi table update")
+    @testing.requires.multi_table_update
     def test_exec_multitable(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
 
@@ -1353,7 +1353,7 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
         expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")]
         self._assert_users(connection, users, expected)
 
-    @testing.only_on("mysql", "Multi table update")
+    @testing.requires.multi_table_update
     def test_exec_join_multitable(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
 
@@ -1377,7 +1377,7 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
         expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")]
         self._assert_users(connection, users, expected)
 
-    @testing.only_on("mysql", "Multi table update")
+    @testing.requires.multi_table_update
     def test_exec_multitable_same_name(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
 
@@ -1471,7 +1471,7 @@ class UpdateFromMultiTableUpdateDefaultsTest(
             ),
         )
 
-    @testing.only_on("mysql", "Multi table update")
+    @testing.requires.multi_table_update
     def test_defaults_second_table(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
 
@@ -1496,7 +1496,7 @@ class UpdateFromMultiTableUpdateDefaultsTest(
         expected = [(8, "ed2", "im the update"), (9, "fred", "value")]
         self._assert_users(connection, users, expected)
 
-    @testing.only_on("mysql", "Multi table update")
+    @testing.requires.multi_table_update
     def test_defaults_second_table_same_name(self, connection):
         users, foobar = self.tables.users, self.tables.foobar
 
@@ -1524,7 +1524,7 @@ class UpdateFromMultiTableUpdateDefaultsTest(
         expected = [(8, "ed2", "im the update"), (9, "fred", "value")]
         self._assert_users(connection, users, expected)
 
-    @testing.only_on("mysql", "Multi table update")
+    @testing.requires.multi_table_update
     def test_no_defaults_second_table(self, connection):
         users, addresses = self.tables.users, self.tables.addresses