]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add delete limit to mysql
authorPablo Nicolas Estevez <pablo22estevez@gmail.com>
Sun, 1 Dec 2024 19:16:24 +0000 (16:16 -0300)
committerPablo Nicolas Estevez <pablo22estevez@gmail.com>
Sun, 1 Dec 2024 19:16:24 +0000 (16:16 -0300)
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/compiler.py
test/dialect/mysql/test_compiler.py
test/orm/dml/test_update_delete_where.py

index c834495759e9c3f5c04f878df243fe8456dfaf78..282e36a0841da4b16b3b823ad3183f48873d024d 100644 (file)
@@ -473,6 +473,10 @@ available.
 
     update(..., mysql_limit=10, mariadb_limit=10)
 
+* DELETE with LIMIT::
+
+    delete(..., mysql_limit=10, mariadb_limit=10)
+
 * optimizer hints, use :meth:`_expression.Select.prefix_with` and
   :meth:`_query.Query.prefix_with`::
 
@@ -1682,6 +1686,13 @@ class MySQLCompiler(compiler.SQLCompiler):
         else:
             return None
 
+    def delete_limit_clause(self, delete_stmt):
+        limit = delete_stmt.kwargs.get("%s_limit" % self.dialect.name, None)
+        if limit:
+            return "LIMIT %s" % limit
+        else:
+            return None
+
     def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
         kw["asfrom"] = True
         return ", ".join(
index 8f58143e6145e5d26b19dc272c4255b84a8c6032..9d23decda6db450387a18753070fab6a02efe532 100644 (file)
@@ -3135,7 +3135,9 @@ class Query(
         )
 
     def delete(
-        self, synchronize_session: SynchronizeSessionArgument = "auto"
+        self,
+        synchronize_session: SynchronizeSessionArgument = "auto",
+        delete_args: Optional[Dict[Any, Any]] = None,
     ) -> int:
         r"""Perform a DELETE with an arbitrary WHERE clause.
 
@@ -3160,6 +3162,12 @@ class Query(
          :ref:`orm_expression_update_delete` for a discussion of these
          strategies.
 
+        :param delete_args: Optional dictionary, if present will be passed
+         to the underlying :func:`_expression.delete`
+         construct as the ``**kw`` for
+         the object.  May be used to pass dialect-specific arguments such
+         as ``mysql_limit``.
+
         :return: the count of rows matched as returned by the database's
           "row count" feature.
 
@@ -3169,7 +3177,9 @@ class Query(
 
         """
 
-        bulk_del = BulkDelete(self)
+        delete_args = delete_args or {}
+
+        bulk_del = BulkDelete(self, delete_args)
         if self.dispatch.before_compile_delete:
             for fn in self.dispatch.before_compile_delete:
                 new_query = fn(bulk_del.query, bulk_del)
@@ -3179,6 +3189,10 @@ class Query(
                 self = bulk_del.query
 
         delete_ = sql.delete(*self._raw_columns)  # type: ignore
+
+        if delete_args:
+            delete_ = delete_.with_dialect_options(**delete_args)
+
         delete_._where_criteria = self._where_criteria
         result: CursorResult[Any] = self.session.execute(
             delete_,
@@ -3409,6 +3423,14 @@ class BulkUpdate(BulkUD):
 class BulkDelete(BulkUD):
     """BulkUD which handles DELETEs."""
 
+    def __init__(
+        self,
+        query: Query[Any],
+        delete_kwargs: Optional[Dict[Any, Any]],
+    ):
+        super().__init__(query)
+        self.delete_kwargs = delete_kwargs
+
 
 class RowReturningQuery(Query[Row[Unpack[_Ts]]]):
     if TYPE_CHECKING:
index 647d38e64015749014759492d5d2399c5f060f47..9783af7bf2b53d85719c1cf5f90ace80e25e1830 100644 (file)
@@ -6102,6 +6102,10 @@ class SQLCompiler(Compiled):
         """Provide a hook for MySQL to add LIMIT to the UPDATE"""
         return None
 
+    def delete_limit_clause(self, delete_stmt):
+        """Provide a hook for MySQL to add LIMIT to the DELETE"""
+        return None
+
     def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
         """Provide a hook to override the initial table clause
         in an UPDATE statement.
@@ -6394,6 +6398,10 @@ class SQLCompiler(Compiled):
             if t:
                 text += " WHERE " + t
 
+        limit_clause = self.delete_limit_clause(delete_stmt)
+        if limit_clause:
+            text += " " + limit_clause
+
         if (
             self.implicit_returning or delete_stmt._returning
         ) and not self.returning_precedes_values:
index f0dcb5838847c92c0884e3aa0899120239dd766b..34b747ed1577dbf5bb65ba40f73278db1a36ec91 100644 (file)
@@ -738,6 +738,25 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
             "UPDATE t SET col1=%s WHERE t.col2 = %s LIMIT 1",
         )
 
+    def test_delete_limit(self):
+        t = sql.table("t", sql.column("col1"), sql.column("col2"))
+
+        self.assert_compile(t.delete(), "DELETE FROM t")
+        self.assert_compile(
+            t.delete().with_dialect_options(mysql_limit=5),
+            "DELETE FROM t SET col1=%s LIMIT 5",
+        )
+        self.assert_compile(
+            t.delete().with_dialect_options(mysql_limit=None),
+            "DELETE FROM t",
+        )
+        self.assert_compile(
+            t.delete()
+            .where(t.c.col2 == 456)
+            .with_dialect_options(mysql_limit=1),
+            "DELETE FRM t WHERE t.col2 = %s LIMIT 1",
+        )
+
     def test_utc_timestamp(self):
         self.assert_compile(func.utc_timestamp(), "utc_timestamp()")
 
index da8efa44fa4f22fa574b41b328c1265a2abf8588..daa2d5bd0c641a968c50e53e075829364e2b5479 100644 (file)
@@ -2585,6 +2585,27 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
             ],
         )
 
+    def test_delete_args(self):
+        Data = self.classes.Data
+        session = fixture_session()
+        delete_args = {"mysql_limit": 1}
+
+        m1 = testing.mock.Mock()
+
+        @event.listens_for(session, "after_bulk_update")
+        def do_orm_execute(bulk_ud):
+            delete_stmt = (
+                bulk_ud.result.context.compiled.compile_state.statement
+            )
+            m1(delete_stmt)
+
+        q = session.query(Data)
+        q.delete(delete_args=delete_args)
+
+        delete_stmt = m1.mock_calls[0][1][0]
+
+        eq_(delete_stmt.dialect_kwargs, delete_args)
+
 
 class ExpressionUpdateTest(fixtures.MappedTest):
     @classmethod