]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement mysql limit() for UPDATE/DELETE DML (patch 2)
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 24 Feb 2025 20:10:54 +0000 (15:10 -0500)
committerMichael Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Mar 2025 14:55:40 +0000 (14:55 +0000)
Added new construct :func:`_mysql.limit` which can be applied to any
:func:`_sql.update` or :func:`_sql.delete` to provide the LIMIT keyword to
UPDATE and DELETE.  This new construct supersedes the use of the
"mysql_limit" dialect keyword argument.

Change-Id: Ie10c2f273432b0c8881a48f5b287f0566dde6ec3

doc/build/changelog/unreleased_21/mysql_limit.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/__init__.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/dml.py
lib/sqlalchemy/sql/compiler.py
test/dialect/mysql/test_compiler.py
test/dialect/mysql/test_query.py

diff --git a/doc/build/changelog/unreleased_21/mysql_limit.rst b/doc/build/changelog/unreleased_21/mysql_limit.rst
new file mode 100644 (file)
index 0000000..cf74e97
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: feature, mysql
+
+    Added new construct :func:`_mysql.limit` which can be applied to any
+    :func:`_sql.update` or :func:`_sql.delete` to provide the LIMIT keyword to
+    UPDATE and DELETE.  This new construct supersedes the use of the
+    "mysql_limit" dialect keyword argument.
+
index 9174c54413a00922e07ddd41e9084b28672b9612..d722c1d30ca31bc159275106ee4a601de62827bc 100644 (file)
@@ -52,6 +52,7 @@ from .base import VARCHAR
 from .base import YEAR
 from .dml import Insert
 from .dml import insert
+from .dml import limit
 from .expression import match
 from .mariadb import INET4
 from .mariadb import INET6
index b57a1e134371cb765152e347ec38e15d63b6352a..7838b455b92a35b059b4c3a7252c6d1e264ed5ee 100644 (file)
@@ -511,16 +511,25 @@ available.
 
     select(...).prefix_with(["HIGH_PRIORITY", "SQL_SMALL_RESULT"])
 
-* UPDATE with LIMIT::
+* UPDATE
+  with LIMIT::
+
+    from sqlalchemy.dialects.mysql import limit
+
+    update(...).ext(limit(10))
 
-    update(...).with_dialect_options(mysql_limit=10, mariadb_limit=10)
+  .. versionchanged:: 2.1 the :func:`_mysql.limit()` extension supersedes the
+     previous use of ``mysql_limit``
 
 * DELETE
   with LIMIT::
 
-    delete(...).with_dialect_options(mysql_limit=10, mariadb_limit=10)
+    from sqlalchemy.dialects.mysql import limit
 
-  .. versionadded:: 2.0.37 Added delete with limit
+    delete(...).ext(limit(10))
+
+  .. versionchanged:: 2.1 the :func:`_mysql.limit()` extension supersedes the
+     previous use of ``mysql_limit``
 
 * optimizer hints, use :meth:`_expression.Select.prefix_with` and
   :meth:`_query.Query.prefix_with`::
@@ -1750,19 +1759,35 @@ class MySQLCompiler(compiler.SQLCompiler):
             # No offset provided, so just use the limit
             return " \n LIMIT %s" % (self.process(limit_clause, **kw),)
 
-    def update_limit_clause(self, update_stmt):
+    def update_post_criteria_clause(self, update_stmt, **kw):
         limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None)
+        supertext = super().update_post_criteria_clause(update_stmt, **kw)
+
         if limit is not None:
-            return f"LIMIT {int(limit)}"
+            limit_text = f"LIMIT {int(limit)}"
+            if supertext is not None:
+                return f"{limit_text} {supertext}"
+            else:
+                return limit_text
         else:
-            return None
+            return supertext
 
-    def delete_limit_clause(self, delete_stmt):
+    def delete_post_criteria_clause(self, delete_stmt, **kw):
         limit = delete_stmt.kwargs.get("%s_limit" % self.dialect.name, None)
+        supertext = super().delete_post_criteria_clause(delete_stmt, **kw)
+
         if limit is not None:
-            return f"LIMIT {int(limit)}"
+            limit_text = f"LIMIT {int(limit)}"
+            if supertext is not None:
+                return f"{limit_text} {supertext}"
+            else:
+                return limit_text
         else:
-            return None
+            return supertext
+
+    def visit_mysql_dml_limit_clause(self, element, **kw):
+        kw["literal_execute"] = True
+        return f"LIMIT {self.process(element._limit_clause, **kw)}"
 
     def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
         kw["asfrom"] = True
index cceb0818f9b1fb308e52dee3ea9b53bd50815fba..f3be3c395d28f8ceb339bffe8f1c8dfce3063d4f 100644 (file)
@@ -12,26 +12,76 @@ from typing import List
 from typing import Mapping
 from typing import Optional
 from typing import Tuple
+from typing import TYPE_CHECKING
 from typing import Union
 
 from ... import exc
 from ... import util
+from ...sql import coercions
+from ...sql import roles
 from ...sql._typing import _DMLTableArgument
 from ...sql.base import _exclusive_against
 from ...sql.base import _generative
 from ...sql.base import ColumnCollection
 from ...sql.base import ReadOnlyColumnCollection
+from ...sql.base import SyntaxExtension
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
 from ...sql.elements import KeyedColumnElement
 from ...sql.expression import alias
 from ...sql.selectable import NamedFromClause
+from ...sql.visitors import InternalTraversal
 from ...util.typing import Self
 
+if TYPE_CHECKING:
+    from ...sql._typing import _LimitOffsetType
+    from ...sql.dml import Delete
+    from ...sql.dml import Update
+    from ...sql.visitors import _TraverseInternalsType
 
 __all__ = ("Insert", "insert")
 
 
+def limit(limit: _LimitOffsetType) -> DMLLimitClause:
+    """apply a LIMIT to an UPDATE or DELETE statement
+
+    e.g.::
+
+        stmt = t.update().values(q="hi").ext(limit(5))
+
+    this supersedes the previous approach of using ``mysql_limit`` for
+    update/delete statements.
+
+    .. versionadded:: 2.1
+
+    """
+    return DMLLimitClause(limit)
+
+
+class DMLLimitClause(SyntaxExtension, ClauseElement):
+    stringify_dialect = "mysql"
+    __visit_name__ = "mysql_dml_limit_clause"
+
+    _traverse_internals: _TraverseInternalsType = [
+        ("_limit_clause", InternalTraversal.dp_clauseelement),
+    ]
+
+    def __init__(self, limit: _LimitOffsetType):
+        self._limit_clause = coercions.expect(
+            roles.LimitOffsetRole, limit, name=None, type_=None
+        )
+
+    def apply_to_update(self, update_stmt: Update) -> None:
+        update_stmt.apply_syntax_extension_point(
+            self.append_replacing_same_type, "post_criteria"
+        )
+
+    def apply_to_delete(self, delete_stmt: Delete) -> None:
+        delete_stmt.apply_syntax_extension_point(
+            self.append_replacing_same_type, "post_criteria"
+        )
+
+
 def insert(table: _DMLTableArgument) -> Insert:
     """Construct a MySQL/MariaDB-specific variant :class:`_mysql.Insert`
     construct.
index 1ee9ff077721ad8e60cfd7b0321ffdb77908bfdb..32043dd7bb4534c632e1cf7fa66bfc6dae708c89 100644 (file)
@@ -6135,14 +6135,6 @@ class SQLCompiler(Compiled):
 
         return text
 
-    def update_limit_clause(self, update_stmt):
-        """Provide a hook for MySQL to add LIMIT to the UPDATE"""
-        return None
-
-    def delete_limit_clause(self, delete_stmt):
-        """Provide a hook for MySQL to add LIMIT to the DELETE"""
-        return None
-
     def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
         """Provide a hook to override the initial table clause
         in an UPDATE statement.
@@ -6165,6 +6157,36 @@ class SQLCompiler(Compiled):
             "criteria within UPDATE"
         )
 
+    def update_post_criteria_clause(self, update_stmt, **kw):
+        """provide a hook to override generation after the WHERE criteria
+        in an UPDATE statement
+
+        .. versionadded:: 2.1
+
+        """
+        if update_stmt._post_criteria_clause is not None:
+            return self.process(
+                update_stmt._post_criteria_clause,
+                **kw,
+            )
+        else:
+            return None
+
+    def delete_post_criteria_clause(self, delete_stmt, **kw):
+        """provide a hook to override generation after the WHERE criteria
+        in a DELETE statement
+
+        .. versionadded:: 2.1
+
+        """
+        if delete_stmt._post_criteria_clause is not None:
+            return self.process(
+                delete_stmt._post_criteria_clause,
+                **kw,
+            )
+        else:
+            return None
+
     def visit_update(self, update_stmt, visiting_cte=None, **kw):
         compile_state = update_stmt._compile_state_factory(
             update_stmt, self, **kw
@@ -6281,19 +6303,11 @@ class SQLCompiler(Compiled):
             if t:
                 text += " WHERE " + t
 
-        limit_clause = self.update_limit_clause(update_stmt)
-        if limit_clause:
-            text += " " + limit_clause
-
-        if update_stmt._post_criteria_clause is not None:
-            ulc = self.process(
-                update_stmt._post_criteria_clause,
-                from_linter=from_linter,
-                **kw,
-            )
-
-            if ulc:
-                text += " " + ulc
+        ulc = self.update_post_criteria_clause(
+            update_stmt, from_linter=from_linter, **kw
+        )
+        if ulc:
+            text += " " + ulc
 
         if (
             self.implicit_returning or update_stmt._returning
@@ -6443,18 +6457,11 @@ class SQLCompiler(Compiled):
             if t:
                 text += " WHERE " + t
 
-        limit_clause = self.delete_limit_clause(delete_stmt)
-        if limit_clause:
-            text += " " + limit_clause
-
-        if delete_stmt._post_criteria_clause is not None:
-            dlc = self.process(
-                delete_stmt._post_criteria_clause,
-                from_linter=from_linter,
-                **kw,
-            )
-            if dlc:
-                text += " " + dlc
+        dlc = self.delete_post_criteria_clause(
+            delete_stmt, from_linter=from_linter, **kw
+        )
+        if dlc:
+            text += " " + dlc
 
         if (
             self.implicit_returning or delete_stmt._returning
index 8387d4e07c67ef7bf551642f789469ec62cfeb1f..5c98be3f6ae1affab0fb1e138837b411fca637dd 100644 (file)
@@ -53,6 +53,7 @@ from sqlalchemy import UnicodeText
 from sqlalchemy import VARCHAR
 from sqlalchemy.dialects.mysql import base as mysql
 from sqlalchemy.dialects.mysql import insert
+from sqlalchemy.dialects.mysql import limit
 from sqlalchemy.dialects.mysql import match
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
@@ -72,6 +73,7 @@ from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import Variation
+from sqlalchemy.testing.fixtures import CacheKeyFixture
 
 
 class ReservedWordFixture(AssertsCompiledSQL):
@@ -623,7 +625,114 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL):
         )
 
 
-class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
+class CustomExtensionTest(
+    fixtures.TestBase, AssertsCompiledSQL, fixtures.CacheKeySuite
+):
+    __dialect__ = "mysql"
+
+    @fixtures.CacheKeySuite.run_suite_tests
+    def test_dml_limit_cache_key(self):
+        t = sql.table("t", sql.column("col1"), sql.column("col2"))
+        return lambda: [
+            t.update().ext(limit(5)),
+            t.delete().ext(limit(5)),
+            t.update(),
+            t.delete(),
+        ]
+
+    def test_update_limit(self):
+        t = sql.table("t", sql.column("col1"), sql.column("col2"))
+
+        self.assert_compile(
+            t.update().values({"col1": 123}).ext(limit(5)),
+            "UPDATE t SET col1=%s LIMIT __[POSTCOMPILE_param_1]",
+            params={"col1": 123, "param_1": 5},
+            check_literal_execute={"param_1": 5},
+        )
+
+        # does not make sense but we want this to compile
+        self.assert_compile(
+            t.update().values({"col1": 123}).ext(limit(0)),
+            "UPDATE t SET col1=%s LIMIT __[POSTCOMPILE_param_1]",
+            params={"col1": 123, "param_1": 0},
+            check_literal_execute={"param_1": 0},
+        )
+
+        # many times is fine too
+        self.assert_compile(
+            t.update()
+            .values({"col1": 123})
+            .ext(limit(0))
+            .ext(limit(3))
+            .ext(limit(42)),
+            "UPDATE t SET col1=%s LIMIT __[POSTCOMPILE_param_1]",
+            params={"col1": 123, "param_1": 42},
+            check_literal_execute={"param_1": 42},
+        )
+
+    def test_delete_limit(self):
+        t = sql.table("t", sql.column("col1"), sql.column("col2"))
+
+        self.assert_compile(
+            t.delete().ext(limit(5)),
+            "DELETE FROM t LIMIT __[POSTCOMPILE_param_1]",
+            params={"param_1": 5},
+            check_literal_execute={"param_1": 5},
+        )
+
+        # does not make sense but we want this to compile
+        self.assert_compile(
+            t.delete().ext(limit(0)),
+            "DELETE FROM t LIMIT __[POSTCOMPILE_param_1]",
+            params={"param_1": 5},
+            check_literal_execute={"param_1": 0},
+        )
+
+        # many times is fine too
+        self.assert_compile(
+            t.delete().ext(limit(0)).ext(limit(3)).ext(limit(42)),
+            "DELETE FROM t LIMIT __[POSTCOMPILE_param_1]",
+            params={"param_1": 42},
+            check_literal_execute={"param_1": 42},
+        )
+
+    @testing.combinations((update,), (delete,))
+    def test_update_delete_limit_int_only(self, crud_fn):
+        t = sql.table("t", sql.column("col1"), sql.column("col2"))
+
+        with expect_raises(ValueError):
+            # note using coercions we get an immediate raise
+            # without having to wait for compilation
+            crud_fn(t).ext(limit("not an int"))
+
+    def test_legacy_update_limit_ext_interaction(self):
+        t = sql.table("t", sql.column("col1"), sql.column("col2"))
+
+        stmt = (
+            t.update()
+            .values({"col1": 123})
+            .with_dialect_options(mysql_limit=5)
+        )
+        stmt.apply_syntax_extension_point(
+            lambda existing: [literal_column("this is a clause")],
+            "post_criteria",
+        )
+        self.assert_compile(
+            stmt, "UPDATE t SET col1=%s LIMIT 5 this is a clause"
+        )
+
+    def test_legacy_delete_limit_ext_interaction(self):
+        t = sql.table("t", sql.column("col1"), sql.column("col2"))
+
+        stmt = t.delete().with_dialect_options(mysql_limit=5)
+        stmt.apply_syntax_extension_point(
+            lambda existing: [literal_column("this is a clause")],
+            "post_criteria",
+        )
+        self.assert_compile(stmt, "DELETE FROM t LIMIT 5 this is a clause")
+
+
+class SQLTest(fixtures.TestBase, AssertsCompiledSQL, CacheKeyFixture):
     """Tests MySQL-dialect specific compilation."""
 
     __dialect__ = mysql.dialect()
@@ -718,7 +827,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=mysql.dialect(),
         )
 
-    def test_update_limit(self):
+    def test_legacy_update_limit(self):
         t = sql.table("t", sql.column("col1"), sql.column("col2"))
 
         self.assert_compile(
@@ -752,7 +861,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
             "UPDATE t SET col1=%s WHERE t.col2 = %s LIMIT 1",
         )
 
-    def test_delete_limit(self):
+    def test_legacy_delete_limit(self):
         t = sql.table("t", sql.column("col1"), sql.column("col2"))
 
         self.assert_compile(t.delete(), "DELETE FROM t")
@@ -777,7 +886,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     @testing.combinations((update,), (delete,))
-    def test_update_delete_limit_int_only(self, crud_fn):
+    def test_legacy_update_delete_limit_int_only(self, crud_fn):
         t = sql.table("t", sql.column("col1"), sql.column("col2"))
 
         with expect_raises(ValueError):
index 9cbc38378fbfe23c660086e04635c080dfcc9e6d..973fe3dbc29ec8a06380dceb978d0bd995f50630 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy import Boolean
 from sqlalchemy import cast
 from sqlalchemy import Column
 from sqlalchemy import Computed
+from sqlalchemy import delete
 from sqlalchemy import exc
 from sqlalchemy import false
 from sqlalchemy import ForeignKey
@@ -16,12 +17,16 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import true
+from sqlalchemy import update
+from sqlalchemy.dialects.mysql import limit
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import combinations
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.fixtures import fixture_session
 
 
 class IdiosyncrasyTest(fixtures.TestBase):
@@ -305,3 +310,127 @@ class ComputedTest(fixtures.TestBase):
         # Create and then drop table
         connection.execute(schema.CreateTable(t))
         connection.execute(schema.DropTable(t))
+
+
+class LimitORMTest(fixtures.MappedTest):
+    __only_on__ = "mysql >= 5.7", "mariadb"
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "users",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(32)),
+            Column("age_int", Integer),
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class User(cls.Comparable):
+            pass
+
+    @classmethod
+    def insert_data(cls, connection):
+        users = cls.tables.users
+
+        connection.execute(
+            users.insert(),
+            [
+                dict(id=1, name="john", age_int=25),
+                dict(id=2, name="jack", age_int=47),
+                dict(id=3, name="jill", age_int=29),
+                dict(id=4, name="jane", age_int=37),
+            ],
+        )
+
+    @classmethod
+    def setup_mappers(cls):
+        User = cls.classes.User
+        users = cls.tables.users
+
+        cls.mapper_registry.map_imperatively(
+            User,
+            users,
+            properties={
+                "age": users.c.age_int,
+            },
+        )
+
+    def test_update_limit_orm_select(self):
+        User = self.classes.User
+
+        s = fixture_session()
+        with self.sql_execution_asserter() as asserter:
+            s.execute(
+                update(User)
+                .where(User.name.startswith("j"))
+                .ext(limit(2))
+                .values({"age": User.age + 3})
+            )
+
+        asserter.assert_(
+            CompiledSQL(
+                "UPDATE users SET age_int=(users.age_int + %s) "
+                "WHERE (users.name LIKE concat(%s, '%%')) "
+                "LIMIT __[POSTCOMPILE_param_1]",
+                [{"age_int_1": 3, "name_1": "j", "param_1": 2}],
+                dialect="mysql",
+            ),
+        )
+
+    def test_delete_limit_orm_select(self):
+        User = self.classes.User
+
+        s = fixture_session()
+        with self.sql_execution_asserter() as asserter:
+            s.execute(
+                delete(User).where(User.name.startswith("j")).ext(limit(2))
+            )
+
+        asserter.assert_(
+            CompiledSQL(
+                "DELETE FROM users WHERE (users.name LIKE concat(%s, '%%')) "
+                "LIMIT __[POSTCOMPILE_param_1]",
+                [{"name_1": "j", "param_1": 2}],
+                dialect="mysql",
+            ),
+        )
+
+    def test_update_limit_legacy_query(self):
+        User = self.classes.User
+
+        s = fixture_session()
+        with self.sql_execution_asserter() as asserter:
+            s.query(User).where(User.name.startswith("j")).ext(
+                limit(2)
+            ).update({"age": User.age + 3})
+
+        asserter.assert_(
+            CompiledSQL(
+                "UPDATE users SET age_int=(users.age_int + %s) "
+                "WHERE (users.name LIKE concat(%s, '%%')) "
+                "LIMIT __[POSTCOMPILE_param_1]",
+                [{"age_int_1": 3, "name_1": "j", "param_1": 2}],
+                dialect="mysql",
+            ),
+        )
+
+    def test_delete_limit_legacy_query(self):
+        User = self.classes.User
+
+        s = fixture_session()
+        with self.sql_execution_asserter() as asserter:
+            s.query(User).where(User.name.startswith("j")).ext(
+                limit(2)
+            ).delete()
+
+        asserter.assert_(
+            CompiledSQL(
+                "DELETE FROM users WHERE (users.name LIKE concat(%s, '%%')) "
+                "LIMIT __[POSTCOMPILE_param_1]",
+                [{"name_1": "j", "param_1": 2}],
+                dialect="mysql",
+            ),
+        )