]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add with_for_update mysql new functionalities
authorRobotScribe <quentinso@theodo.fr>
Wed, 29 Apr 2020 19:22:59 +0000 (15:22 -0400)
committerGord Thompson <gord@gordthompson.com>
Fri, 15 May 2020 21:50:32 +0000 (15:50 -0600)
Fixes: #4860
# Description
Add nowait, skip_lock, of arguments to for_update_clause for mysql

### Checklist

This pull request is:

- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [x] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #5290
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/5290
Pull-request-sha: 490e822e73e92ffe63cf45df9c49f3b31af1954d

Change-Id: Ibd2acc47b538c601c69c8fb954776035ecab4c6c

doc/build/changelog/unreleased_13/4860.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/engine/default.py
test/dialect/mysql/test_compiler.py
test/dialect/mysql/test_for_update.py
test/dialect/postgresql/test_compiler.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_13/4860.rst b/doc/build/changelog/unreleased_13/4860.rst
new file mode 100644 (file)
index 0000000..b526ce3
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: usecase, mysql
+    :tickets: 4860
+
+    Implemented row-level locking support for mysql.  Pull request courtesy
+    Quentin Somerville.
\ No newline at end of file
index dca7b9a00157e0767fef05382197fc764fae9bdc..d009d656edefac2f1ae4b5bd02efa8a8bd008c56 100644 (file)
@@ -808,6 +808,7 @@ from ...sql import coercions
 from ...sql import compiler
 from ...sql import elements
 from ...sql import roles
+from ...sql import util as sql_util
 from ...types import BINARY
 from ...types import BLOB
 from ...types import BOOLEAN
@@ -1494,9 +1495,28 @@ class MySQLCompiler(compiler.SQLCompiler):
 
     def for_update_clause(self, select, **kw):
         if select._for_update_arg.read:
-            return " LOCK IN SHARE MODE"
+            tmp = " LOCK IN SHARE MODE"
         else:
-            return " FOR UPDATE"
+            tmp = " FOR UPDATE"
+
+        if select._for_update_arg.of and self.dialect.supports_for_update_of:
+
+            tables = util.OrderedSet()
+            for c in select._for_update_arg.of:
+                tables.update(sql_util.surface_selectables_only(c))
+
+            tmp += " OF " + ", ".join(
+                self.process(table, ashint=True, use_schema=False, **kw)
+                for table in tables
+            )
+
+        if select._for_update_arg.nowait:
+            tmp += " NOWAIT"
+
+        if select._for_update_arg.skip_locked and self.dialect._is_mysql:
+            tmp += " SKIP LOCKED"
+
+        return tmp
 
     def limit_clause(self, select, **kw):
         # MySQL supports:
@@ -2211,6 +2231,9 @@ class MySQLDialect(default.DefaultDialect):
 
     sequences_optional = True
 
+    supports_for_update_of = False  # default for MySQL ...
+    # ... may be updated to True for MySQL 8+ in initialize()
+
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = False
     supports_multivalues_insert = True
@@ -2526,6 +2549,10 @@ class MySQLDialect(default.DefaultDialect):
             self._is_mariadb and self.server_version_info >= (10, 3)
         )
 
+        self.supports_for_update_of = (
+            self._is_mysql and self.server_version_info >= (8,)
+        )
+
         self._needs_correct_for_88718_96365 = (
             not self._is_mariadb and self.server_version_info >= (8,)
         )
index 20f73111602800c7a3e44f1ca3c4fc076ea0063b..b17549668b009efc2951cb0e3e23e8862fdc3db8 100644 (file)
@@ -128,6 +128,9 @@ class DefaultDialect(interfaces.Dialect):
 
     supports_server_side_cursors = False
 
+    # extra record-level locking features (#4860)
+    supports_for_update_of = False
+
     server_version_info = None
 
     default_schema_name = None
index 4e6199c6f2bfb075c9de2239081bf3d12722a070..167460cba07bf443d9a737b7b124e0128d3a6bd4 100644 (file)
@@ -353,21 +353,30 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         expr = literal("x", type_=String) + literal("y", type_=String)
         self.assert_compile(expr, "concat('x', 'y')", literal_binds=True)
 
-    def test_for_update(self):
+    def test_mariadb_for_update(self):
+        dialect = mysql.dialect()
+        dialect.server_version_info = (10, 1, 1, "MariaDB")
+
         table1 = table(
             "mytable", column("myid"), column("name"), column("description")
         )
 
         self.assert_compile(
-            table1.select(table1.c.myid == 7).with_for_update(),
+            table1.select(table1.c.myid == 7).with_for_update(of=table1),
             "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %s FOR UPDATE",
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE",
+            dialect=dialect,
         )
 
         self.assert_compile(
-            table1.select(table1.c.myid == 7).with_for_update(read=True),
+            table1.select(table1.c.myid == 7).with_for_update(
+                skip_locked=True
+            ),
             "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE",
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE",
+            dialect=dialect,
         )
 
     def test_delete_extra_froms(self):
index 5897a094dfe78db48fabc66a9325d5d920257d29..2c247a5c091916f061d2e3423afd27a552077da5 100644 (file)
@@ -11,9 +11,13 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import testing
 from sqlalchemy import update
+from sqlalchemy.dialects.mysql import base as mysql
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
+from sqlalchemy.sql import column
+from sqlalchemy.sql import table
+from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import fixtures
 
 
@@ -160,3 +164,196 @@ class MySQLForUpdateLockingTest(fixtures.DeclarativeMappedTest):
             # no subquery, should be locked
             self._assert_a_is_locked(True)
             self._assert_b_is_locked(True)
+
+
+class MySQLForUpdateCompileTest(fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = mysql.dialect()
+
+    table1 = table(
+        "mytable", column("myid"), column("name"), column("description")
+    )
+    table2 = table("table2", column("mytable_id"))
+    join = table2.join(table1, table2.c.mytable_id == table1.c.myid)
+    for_update_of_dialect = mysql.dialect()
+    for_update_of_dialect.server_version_info = (8, 0, 0)
+    for_update_of_dialect.supports_for_update_of = True
+
+    def test_for_update_basic(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s FOR UPDATE",
+        )
+
+    def test_for_update_read(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE",
+        )
+
+    def test_for_update_skip_locked(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE SKIP LOCKED",
+        )
+
+    def test_for_update_read_and_skip_locked(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True, skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE SKIP LOCKED",
+        )
+
+    def test_for_update_nowait(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE NOWAIT",
+        )
+
+    def test_for_update_read_and_nowait(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE NOWAIT",
+        )
+
+    def test_for_update_of_nowait(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                of=self.table1, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE OF mytable NOWAIT",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_basic(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                of=self.table1
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE OF mytable",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_skip_locked(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                of=self.table1, skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE OF mytable SKIP LOCKED",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_join_one(self):
+        self.assert_compile(
+            self.join.select(self.table2.c.mytable_id == 7).with_for_update(
+                of=[self.join]
+            ),
+            "SELECT table2.mytable_id, "
+            "mytable.myid, mytable.name, mytable.description "
+            "FROM table2 "
+            "INNER JOIN mytable ON table2.mytable_id = mytable.myid "
+            "WHERE table2.mytable_id = %s "
+            "FOR UPDATE OF mytable, table2",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_column_list_aliased(self):
+        ta = self.table1.alias()
+        self.assert_compile(
+            ta.select(ta.c.myid == 7).with_for_update(
+                of=[ta.c.myid, ta.c.name]
+            ),
+            "SELECT mytable_1.myid, mytable_1.name, mytable_1.description "
+            "FROM mytable AS mytable_1 "
+            "WHERE mytable_1.myid = %s FOR UPDATE OF mytable_1",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_join_aliased(self):
+        ta = self.table1.alias()
+        alias_join = self.table2.join(
+            ta, self.table2.c.mytable_id == ta.c.myid
+        )
+        self.assert_compile(
+            alias_join.select(self.table2.c.mytable_id == 7).with_for_update(
+                of=[alias_join]
+            ),
+            "SELECT table2.mytable_id, "
+            "mytable_1.myid, mytable_1.name, mytable_1.description "
+            "FROM table2 "
+            "INNER JOIN mytable AS mytable_1 "
+            "ON table2.mytable_id = mytable_1.myid "
+            "WHERE table2.mytable_id = %s "
+            "FOR UPDATE OF mytable_1, table2",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_read_nowait(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True, of=self.table1, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable NOWAIT",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_read_skip_locked(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True, of=self.table1, skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable SKIP LOCKED",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_read_nowait_column_list(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True,
+                of=[self.table1.c.myid, self.table1.c.name],
+                nowait=True,
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable NOWAIT",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_read(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True, of=self.table1
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable",
+            dialect=self.for_update_of_dialect,
+        )
index 4cc9c837d6eacf1960c0dc1d51135b5d19c025ca..c707137a819c163785ddcfe4c07d8c061d954032 100644 (file)
@@ -950,6 +950,24 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT",
         )
 
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                key_share=True, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR NO KEY UPDATE NOWAIT",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                key_share=True, read=True, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR KEY SHARE NOWAIT",
+        )
+
         self.assert_compile(
             table1.select(table1.c.myid == 7).with_for_update(
                 read=True, skip_locked=True
@@ -977,6 +995,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "FOR SHARE OF mytable NOWAIT",
         )
 
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                key_share=True, read=True, nowait=True, of=table1
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR KEY SHARE OF mytable NOWAIT",
+        )
+
         self.assert_compile(
             table1.select(table1.c.myid == 7).with_for_update(
                 read=True, nowait=True, of=table1.c.myid
@@ -995,6 +1022,27 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "FOR SHARE OF mytable NOWAIT",
         )
 
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                read=True,
+                skip_locked=True,
+                of=[table1.c.myid, table1.c.name],
+                key_share=True,
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR KEY SHARE OF mytable SKIP LOCKED",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                skip_locked=True, of=[table1.c.myid, table1.c.name]
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR UPDATE OF mytable SKIP LOCKED",
+        )
+
         self.assert_compile(
             table1.select(table1.c.myid == 7).with_for_update(
                 read=True, skip_locked=True, of=[table1.c.myid, table1.c.name]
@@ -1058,6 +1106,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "FOR KEY SHARE OF mytable",
         )
 
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                read=True, of=table1
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR SHARE OF mytable",
+        )
+
         self.assert_compile(
             table1.select(table1.c.myid == 7).with_for_update(
                 read=True, key_share=True, skip_locked=True
@@ -1067,6 +1124,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "FOR KEY SHARE SKIP LOCKED",
         )
 
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                key_share=True, skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR NO KEY UPDATE SKIP LOCKED",
+        )
+
         ta = table1.alias()
         self.assert_compile(
             ta.select(ta.c.myid == 7).with_for_update(
index cf9168f5a23d184f827426f72d112e1cc2603e39..d0560f579ea51047996ce4dae7065c73e8b6dae0 100644 (file)
@@ -1616,3 +1616,7 @@ class DefaultRequirements(SuiteRequirements):
     def supports_distinct_on(self):
         """If a backend supports the DISTINCT ON in a select"""
         return only_if(["postgresql"])
+
+    @property
+    def supports_for_update_of(self):
+        return only_if(lambda config: config.db.dialect.supports_for_update_of)