From: Mike Bayer Date: Mon, 24 Feb 2025 20:10:54 +0000 (-0500) Subject: implement mysql limit() for UPDATE/DELETE DML (patch 2) X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=33be2722905f74562cb47cf6c23948065ae91e47;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement mysql limit() for UPDATE/DELETE DML (patch 2) Added new construct :func:`_mysql.limit` which can be applied to any :func:`_sql.update` or :func:`_sql.delete` to provide the LIMIT keyword to UPDATE and DELETE. This new construct supersedes the use of the "mysql_limit" dialect keyword argument. Change-Id: Ie10c2f273432b0c8881a48f5b287f0566dde6ec3 --- diff --git a/doc/build/changelog/unreleased_21/mysql_limit.rst b/doc/build/changelog/unreleased_21/mysql_limit.rst new file mode 100644 index 0000000000..cf74e97a44 --- /dev/null +++ b/doc/build/changelog/unreleased_21/mysql_limit.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: feature, mysql + + Added new construct :func:`_mysql.limit` which can be applied to any + :func:`_sql.update` or :func:`_sql.delete` to provide the LIMIT keyword to + UPDATE and DELETE. This new construct supersedes the use of the + "mysql_limit" dialect keyword argument. + diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py index 9174c54413..d722c1d30c 100644 --- a/lib/sqlalchemy/dialects/mysql/__init__.py +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -52,6 +52,7 @@ from .base import VARCHAR from .base import YEAR from .dml import Insert from .dml import insert +from .dml import limit from .expression import match from .mariadb import INET4 from .mariadb import INET6 diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index b57a1e1343..7838b455b9 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -511,16 +511,25 @@ available. select(...).prefix_with(["HIGH_PRIORITY", "SQL_SMALL_RESULT"]) -* UPDATE with LIMIT:: +* UPDATE + with LIMIT:: + + from sqlalchemy.dialects.mysql import limit + + update(...).ext(limit(10)) - update(...).with_dialect_options(mysql_limit=10, mariadb_limit=10) + .. versionchanged:: 2.1 the :func:`_mysql.limit()` extension supersedes the + previous use of ``mysql_limit`` * DELETE with LIMIT:: - delete(...).with_dialect_options(mysql_limit=10, mariadb_limit=10) + from sqlalchemy.dialects.mysql import limit - .. versionadded:: 2.0.37 Added delete with limit + delete(...).ext(limit(10)) + + .. versionchanged:: 2.1 the :func:`_mysql.limit()` extension supersedes the + previous use of ``mysql_limit`` * optimizer hints, use :meth:`_expression.Select.prefix_with` and :meth:`_query.Query.prefix_with`:: @@ -1750,19 +1759,35 @@ class MySQLCompiler(compiler.SQLCompiler): # No offset provided, so just use the limit return " \n LIMIT %s" % (self.process(limit_clause, **kw),) - def update_limit_clause(self, update_stmt): + def update_post_criteria_clause(self, update_stmt, **kw): limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None) + supertext = super().update_post_criteria_clause(update_stmt, **kw) + if limit is not None: - return f"LIMIT {int(limit)}" + limit_text = f"LIMIT {int(limit)}" + if supertext is not None: + return f"{limit_text} {supertext}" + else: + return limit_text else: - return None + return supertext - def delete_limit_clause(self, delete_stmt): + def delete_post_criteria_clause(self, delete_stmt, **kw): limit = delete_stmt.kwargs.get("%s_limit" % self.dialect.name, None) + supertext = super().delete_post_criteria_clause(delete_stmt, **kw) + if limit is not None: - return f"LIMIT {int(limit)}" + limit_text = f"LIMIT {int(limit)}" + if supertext is not None: + return f"{limit_text} {supertext}" + else: + return limit_text else: - return None + return supertext + + def visit_mysql_dml_limit_clause(self, element, **kw): + kw["literal_execute"] = True + return f"LIMIT {self.process(element._limit_clause, **kw)}" def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): kw["asfrom"] = True diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py index cceb0818f9..f3be3c395d 100644 --- a/lib/sqlalchemy/dialects/mysql/dml.py +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -12,26 +12,76 @@ from typing import List from typing import Mapping from typing import Optional from typing import Tuple +from typing import TYPE_CHECKING from typing import Union from ... import exc from ... import util +from ...sql import coercions +from ...sql import roles from ...sql._typing import _DMLTableArgument from ...sql.base import _exclusive_against from ...sql.base import _generative from ...sql.base import ColumnCollection from ...sql.base import ReadOnlyColumnCollection +from ...sql.base import SyntaxExtension from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement from ...sql.elements import KeyedColumnElement from ...sql.expression import alias from ...sql.selectable import NamedFromClause +from ...sql.visitors import InternalTraversal from ...util.typing import Self +if TYPE_CHECKING: + from ...sql._typing import _LimitOffsetType + from ...sql.dml import Delete + from ...sql.dml import Update + from ...sql.visitors import _TraverseInternalsType __all__ = ("Insert", "insert") +def limit(limit: _LimitOffsetType) -> DMLLimitClause: + """apply a LIMIT to an UPDATE or DELETE statement + + e.g.:: + + stmt = t.update().values(q="hi").ext(limit(5)) + + this supersedes the previous approach of using ``mysql_limit`` for + update/delete statements. + + .. versionadded:: 2.1 + + """ + return DMLLimitClause(limit) + + +class DMLLimitClause(SyntaxExtension, ClauseElement): + stringify_dialect = "mysql" + __visit_name__ = "mysql_dml_limit_clause" + + _traverse_internals: _TraverseInternalsType = [ + ("_limit_clause", InternalTraversal.dp_clauseelement), + ] + + def __init__(self, limit: _LimitOffsetType): + self._limit_clause = coercions.expect( + roles.LimitOffsetRole, limit, name=None, type_=None + ) + + def apply_to_update(self, update_stmt: Update) -> None: + update_stmt.apply_syntax_extension_point( + self.append_replacing_same_type, "post_criteria" + ) + + def apply_to_delete(self, delete_stmt: Delete) -> None: + delete_stmt.apply_syntax_extension_point( + self.append_replacing_same_type, "post_criteria" + ) + + def insert(table: _DMLTableArgument) -> Insert: """Construct a MySQL/MariaDB-specific variant :class:`_mysql.Insert` construct. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 1ee9ff0777..32043dd7bb 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -6135,14 +6135,6 @@ class SQLCompiler(Compiled): return text - def update_limit_clause(self, update_stmt): - """Provide a hook for MySQL to add LIMIT to the UPDATE""" - return None - - def delete_limit_clause(self, delete_stmt): - """Provide a hook for MySQL to add LIMIT to the DELETE""" - return None - def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): """Provide a hook to override the initial table clause in an UPDATE statement. @@ -6165,6 +6157,36 @@ class SQLCompiler(Compiled): "criteria within UPDATE" ) + def update_post_criteria_clause(self, update_stmt, **kw): + """provide a hook to override generation after the WHERE criteria + in an UPDATE statement + + .. versionadded:: 2.1 + + """ + if update_stmt._post_criteria_clause is not None: + return self.process( + update_stmt._post_criteria_clause, + **kw, + ) + else: + return None + + def delete_post_criteria_clause(self, delete_stmt, **kw): + """provide a hook to override generation after the WHERE criteria + in a DELETE statement + + .. versionadded:: 2.1 + + """ + if delete_stmt._post_criteria_clause is not None: + return self.process( + delete_stmt._post_criteria_clause, + **kw, + ) + else: + return None + def visit_update(self, update_stmt, visiting_cte=None, **kw): compile_state = update_stmt._compile_state_factory( update_stmt, self, **kw @@ -6281,19 +6303,11 @@ class SQLCompiler(Compiled): if t: text += " WHERE " + t - limit_clause = self.update_limit_clause(update_stmt) - if limit_clause: - text += " " + limit_clause - - if update_stmt._post_criteria_clause is not None: - ulc = self.process( - update_stmt._post_criteria_clause, - from_linter=from_linter, - **kw, - ) - - if ulc: - text += " " + ulc + ulc = self.update_post_criteria_clause( + update_stmt, from_linter=from_linter, **kw + ) + if ulc: + text += " " + ulc if ( self.implicit_returning or update_stmt._returning @@ -6443,18 +6457,11 @@ class SQLCompiler(Compiled): if t: text += " WHERE " + t - limit_clause = self.delete_limit_clause(delete_stmt) - if limit_clause: - text += " " + limit_clause - - if delete_stmt._post_criteria_clause is not None: - dlc = self.process( - delete_stmt._post_criteria_clause, - from_linter=from_linter, - **kw, - ) - if dlc: - text += " " + dlc + dlc = self.delete_post_criteria_clause( + delete_stmt, from_linter=from_linter, **kw + ) + if dlc: + text += " " + dlc if ( self.implicit_returning or delete_stmt._returning diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 8387d4e07c..5c98be3f6a 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -53,6 +53,7 @@ from sqlalchemy import UnicodeText from sqlalchemy import VARCHAR from sqlalchemy.dialects.mysql import base as mysql from sqlalchemy.dialects.mysql import insert +from sqlalchemy.dialects.mysql import limit from sqlalchemy.dialects.mysql import match from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped @@ -72,6 +73,7 @@ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock from sqlalchemy.testing import Variation +from sqlalchemy.testing.fixtures import CacheKeyFixture class ReservedWordFixture(AssertsCompiledSQL): @@ -623,7 +625,114 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL): ) -class SQLTest(fixtures.TestBase, AssertsCompiledSQL): +class CustomExtensionTest( + fixtures.TestBase, AssertsCompiledSQL, fixtures.CacheKeySuite +): + __dialect__ = "mysql" + + @fixtures.CacheKeySuite.run_suite_tests + def test_dml_limit_cache_key(self): + t = sql.table("t", sql.column("col1"), sql.column("col2")) + return lambda: [ + t.update().ext(limit(5)), + t.delete().ext(limit(5)), + t.update(), + t.delete(), + ] + + def test_update_limit(self): + t = sql.table("t", sql.column("col1"), sql.column("col2")) + + self.assert_compile( + t.update().values({"col1": 123}).ext(limit(5)), + "UPDATE t SET col1=%s LIMIT __[POSTCOMPILE_param_1]", + params={"col1": 123, "param_1": 5}, + check_literal_execute={"param_1": 5}, + ) + + # does not make sense but we want this to compile + self.assert_compile( + t.update().values({"col1": 123}).ext(limit(0)), + "UPDATE t SET col1=%s LIMIT __[POSTCOMPILE_param_1]", + params={"col1": 123, "param_1": 0}, + check_literal_execute={"param_1": 0}, + ) + + # many times is fine too + self.assert_compile( + t.update() + .values({"col1": 123}) + .ext(limit(0)) + .ext(limit(3)) + .ext(limit(42)), + "UPDATE t SET col1=%s LIMIT __[POSTCOMPILE_param_1]", + params={"col1": 123, "param_1": 42}, + check_literal_execute={"param_1": 42}, + ) + + def test_delete_limit(self): + t = sql.table("t", sql.column("col1"), sql.column("col2")) + + self.assert_compile( + t.delete().ext(limit(5)), + "DELETE FROM t LIMIT __[POSTCOMPILE_param_1]", + params={"param_1": 5}, + check_literal_execute={"param_1": 5}, + ) + + # does not make sense but we want this to compile + self.assert_compile( + t.delete().ext(limit(0)), + "DELETE FROM t LIMIT __[POSTCOMPILE_param_1]", + params={"param_1": 5}, + check_literal_execute={"param_1": 0}, + ) + + # many times is fine too + self.assert_compile( + t.delete().ext(limit(0)).ext(limit(3)).ext(limit(42)), + "DELETE FROM t LIMIT __[POSTCOMPILE_param_1]", + params={"param_1": 42}, + check_literal_execute={"param_1": 42}, + ) + + @testing.combinations((update,), (delete,)) + def test_update_delete_limit_int_only(self, crud_fn): + t = sql.table("t", sql.column("col1"), sql.column("col2")) + + with expect_raises(ValueError): + # note using coercions we get an immediate raise + # without having to wait for compilation + crud_fn(t).ext(limit("not an int")) + + def test_legacy_update_limit_ext_interaction(self): + t = sql.table("t", sql.column("col1"), sql.column("col2")) + + stmt = ( + t.update() + .values({"col1": 123}) + .with_dialect_options(mysql_limit=5) + ) + stmt.apply_syntax_extension_point( + lambda existing: [literal_column("this is a clause")], + "post_criteria", + ) + self.assert_compile( + stmt, "UPDATE t SET col1=%s LIMIT 5 this is a clause" + ) + + def test_legacy_delete_limit_ext_interaction(self): + t = sql.table("t", sql.column("col1"), sql.column("col2")) + + stmt = t.delete().with_dialect_options(mysql_limit=5) + stmt.apply_syntax_extension_point( + lambda existing: [literal_column("this is a clause")], + "post_criteria", + ) + self.assert_compile(stmt, "DELETE FROM t LIMIT 5 this is a clause") + + +class SQLTest(fixtures.TestBase, AssertsCompiledSQL, CacheKeyFixture): """Tests MySQL-dialect specific compilation.""" __dialect__ = mysql.dialect() @@ -718,7 +827,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): dialect=mysql.dialect(), ) - def test_update_limit(self): + def test_legacy_update_limit(self): t = sql.table("t", sql.column("col1"), sql.column("col2")) self.assert_compile( @@ -752,7 +861,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): "UPDATE t SET col1=%s WHERE t.col2 = %s LIMIT 1", ) - def test_delete_limit(self): + def test_legacy_delete_limit(self): t = sql.table("t", sql.column("col1"), sql.column("col2")) self.assert_compile(t.delete(), "DELETE FROM t") @@ -777,7 +886,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): ) @testing.combinations((update,), (delete,)) - def test_update_delete_limit_int_only(self, crud_fn): + def test_legacy_update_delete_limit_int_only(self, crud_fn): t = sql.table("t", sql.column("col1"), sql.column("col2")) with expect_raises(ValueError): diff --git a/test/dialect/mysql/test_query.py b/test/dialect/mysql/test_query.py index 9cbc38378f..973fe3dbc2 100644 --- a/test/dialect/mysql/test_query.py +++ b/test/dialect/mysql/test_query.py @@ -5,6 +5,7 @@ from sqlalchemy import Boolean from sqlalchemy import cast from sqlalchemy import Column from sqlalchemy import Computed +from sqlalchemy import delete from sqlalchemy import exc from sqlalchemy import false from sqlalchemy import ForeignKey @@ -16,12 +17,16 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import true +from sqlalchemy import update +from sqlalchemy.dialects.mysql import limit from sqlalchemy.testing import assert_raises from sqlalchemy.testing import combinations from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session class IdiosyncrasyTest(fixtures.TestBase): @@ -305,3 +310,127 @@ class ComputedTest(fixtures.TestBase): # Create and then drop table connection.execute(schema.CreateTable(t)) connection.execute(schema.DropTable(t)) + + +class LimitORMTest(fixtures.MappedTest): + __only_on__ = "mysql >= 5.7", "mariadb" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "users", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(32)), + Column("age_int", Integer), + ) + + @classmethod + def setup_classes(cls): + class User(cls.Comparable): + pass + + @classmethod + def insert_data(cls, connection): + users = cls.tables.users + + connection.execute( + users.insert(), + [ + dict(id=1, name="john", age_int=25), + dict(id=2, name="jack", age_int=47), + dict(id=3, name="jill", age_int=29), + dict(id=4, name="jane", age_int=37), + ], + ) + + @classmethod + def setup_mappers(cls): + User = cls.classes.User + users = cls.tables.users + + cls.mapper_registry.map_imperatively( + User, + users, + properties={ + "age": users.c.age_int, + }, + ) + + def test_update_limit_orm_select(self): + User = self.classes.User + + s = fixture_session() + with self.sql_execution_asserter() as asserter: + s.execute( + update(User) + .where(User.name.startswith("j")) + .ext(limit(2)) + .values({"age": User.age + 3}) + ) + + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=(users.age_int + %s) " + "WHERE (users.name LIKE concat(%s, '%%')) " + "LIMIT __[POSTCOMPILE_param_1]", + [{"age_int_1": 3, "name_1": "j", "param_1": 2}], + dialect="mysql", + ), + ) + + def test_delete_limit_orm_select(self): + User = self.classes.User + + s = fixture_session() + with self.sql_execution_asserter() as asserter: + s.execute( + delete(User).where(User.name.startswith("j")).ext(limit(2)) + ) + + asserter.assert_( + CompiledSQL( + "DELETE FROM users WHERE (users.name LIKE concat(%s, '%%')) " + "LIMIT __[POSTCOMPILE_param_1]", + [{"name_1": "j", "param_1": 2}], + dialect="mysql", + ), + ) + + def test_update_limit_legacy_query(self): + User = self.classes.User + + s = fixture_session() + with self.sql_execution_asserter() as asserter: + s.query(User).where(User.name.startswith("j")).ext( + limit(2) + ).update({"age": User.age + 3}) + + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=(users.age_int + %s) " + "WHERE (users.name LIKE concat(%s, '%%')) " + "LIMIT __[POSTCOMPILE_param_1]", + [{"age_int_1": 3, "name_1": "j", "param_1": 2}], + dialect="mysql", + ), + ) + + def test_delete_limit_legacy_query(self): + User = self.classes.User + + s = fixture_session() + with self.sql_execution_asserter() as asserter: + s.query(User).where(User.name.startswith("j")).ext( + limit(2) + ).delete() + + asserter.assert_( + CompiledSQL( + "DELETE FROM users WHERE (users.name LIKE concat(%s, '%%')) " + "LIMIT __[POSTCOMPILE_param_1]", + [{"name_1": "j", "param_1": 2}], + dialect="mysql", + ), + )