]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add with_for_update mysql new functionalities
authorRobotScribe <quentinso@theodo.fr>
Wed, 29 Apr 2020 19:22:59 +0000 (15:22 -0400)
committerGord Thompson <gord@gordthompson.com>
Sun, 24 May 2020 11:32:32 +0000 (05:32 -0600)
Fixes: #4860
# Description
Add nowait, skip_lock, of arguments to for_update_clause for mysql

### Checklist

This pull request is:

- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [x] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #5290
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/5290
Pull-request-sha: 490e822e73e92ffe63cf45df9c49f3b31af1954d

Change-Id: Ibd2acc47b538c601c69c8fb954776035ecab4c6c
(cherry picked from commit 103260ddb476c5354b3201f92636c474f2a83c35)

doc/build/changelog/unreleased_13/4860.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/engine/default.py
test/dialect/mysql/test_compiler.py
test/dialect/mysql/test_for_update.py
test/dialect/postgresql/test_compiler.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_13/4860.rst b/doc/build/changelog/unreleased_13/4860.rst
new file mode 100644 (file)
index 0000000..b526ce3
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: usecase, mysql
+    :tickets: 4860
+
+    Implemented row-level locking support for mysql.  Pull request courtesy
+    Quentin Somerville.
\ No newline at end of file
index cda85c0db2596305895573d948b595ea69511596..12156a1de67108f51e52790347c78a3395a3e6af 100644 (file)
@@ -806,6 +806,7 @@ from ...engine import default
 from ...engine import reflection
 from ...sql import compiler
 from ...sql import elements
+from ...sql import util as sql_util
 from ...types import BINARY
 from ...types import BLOB
 from ...types import BOOLEAN
@@ -1470,9 +1471,28 @@ class MySQLCompiler(compiler.SQLCompiler):
 
     def for_update_clause(self, select, **kw):
         if select._for_update_arg.read:
-            return " LOCK IN SHARE MODE"
+            tmp = " LOCK IN SHARE MODE"
         else:
-            return " FOR UPDATE"
+            tmp = " FOR UPDATE"
+
+        if select._for_update_arg.of and self.dialect.supports_for_update_of:
+
+            tables = util.OrderedSet()
+            for c in select._for_update_arg.of:
+                tables.update(sql_util.surface_selectables_only(c))
+
+            tmp += " OF " + ", ".join(
+                self.process(table, ashint=True, use_schema=False, **kw)
+                for table in tables
+            )
+
+        if select._for_update_arg.nowait:
+            tmp += " NOWAIT"
+
+        if select._for_update_arg.skip_locked and self.dialect._is_mysql:
+            tmp += " SKIP LOCKED"
+
+        return tmp
 
     def limit_clause(self, select, **kw):
         # MySQL supports:
@@ -2186,6 +2206,9 @@ class MySQLDialect(default.DefaultDialect):
 
     supports_native_enum = True
 
+    supports_for_update_of = False  # default for MySQL ...
+    # ... may be updated to True for MySQL 8+ in initialize()
+
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = False
     supports_multivalues_insert = True
@@ -2482,6 +2505,10 @@ class MySQLDialect(default.DefaultDialect):
 
         default.DefaultDialect.initialize(self, connection)
 
+        self.supports_for_update_of = (
+            self._is_mysql and self.server_version_info >= (8,)
+        )
+
         self._needs_correct_for_88718_96365 = (
             not self._is_mariadb and self.server_version_info >= (8,)
         )
index 51977f880e40794c42c7e250cd6164e4a10bc614..06f7a86186df70620c17cd0d6e0bb821d64b31d0 100644 (file)
@@ -133,6 +133,9 @@ class DefaultDialect(interfaces.Dialect):
 
     supports_server_side_cursors = False
 
+    # extra record-level locking features (#4860)
+    supports_for_update_of = False
+
     server_version_info = None
 
     construct_arguments = None
index e74c37d63da69d440bbcafca446eca0de9062712..ec3c8bc13952c4490282d17baa73e83f65588e35 100644 (file)
@@ -353,21 +353,30 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         expr = literal("x", type_=String) + literal("y", type_=String)
         self.assert_compile(expr, "concat('x', 'y')", literal_binds=True)
 
-    def test_for_update(self):
+    def test_mariadb_for_update(self):
+        dialect = mysql.dialect()
+        dialect.server_version_info = (10, 1, 1, "MariaDB")
+
         table1 = table(
             "mytable", column("myid"), column("name"), column("description")
         )
 
         self.assert_compile(
-            table1.select(table1.c.myid == 7).with_for_update(),
+            table1.select(table1.c.myid == 7).with_for_update(of=table1),
             "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %s FOR UPDATE",
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE",
+            dialect=dialect,
         )
 
         self.assert_compile(
-            table1.select(table1.c.myid == 7).with_for_update(read=True),
+            table1.select(table1.c.myid == 7).with_for_update(
+                skip_locked=True
+            ),
             "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE",
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE",
+            dialect=dialect,
         )
 
     def test_delete_extra_froms(self):
index 948be0797719b2aaa370f8f931ace847e0729f32..2d672cb3dd60bf3dfc59fda3a8a3cdebafe4c94e 100644 (file)
@@ -11,9 +11,13 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import testing
 from sqlalchemy import update
+from sqlalchemy.dialects.mysql import base as mysql
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
+from sqlalchemy.sql import column
+from sqlalchemy.sql import table
+from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import fixtures
 
 
@@ -160,3 +164,196 @@ class MySQLForUpdateLockingTest(fixtures.DeclarativeMappedTest):
             # no subquery, should be locked
             self._assert_a_is_locked(True)
             self._assert_b_is_locked(True)
+
+
+class MySQLForUpdateCompileTest(fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = mysql.dialect()
+
+    table1 = table(
+        "mytable", column("myid"), column("name"), column("description")
+    )
+    table2 = table("table2", column("mytable_id"))
+    join = table2.join(table1, table2.c.mytable_id == table1.c.myid)
+    for_update_of_dialect = mysql.dialect()
+    for_update_of_dialect.server_version_info = (8, 0, 0)
+    for_update_of_dialect.supports_for_update_of = True
+
+    def test_for_update_basic(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s FOR UPDATE",
+        )
+
+    def test_for_update_read(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE",
+        )
+
+    def test_for_update_skip_locked(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE SKIP LOCKED",
+        )
+
+    def test_for_update_read_and_skip_locked(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True, skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE SKIP LOCKED",
+        )
+
+    def test_for_update_nowait(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE NOWAIT",
+        )
+
+    def test_for_update_read_and_nowait(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE NOWAIT",
+        )
+
+    def test_for_update_of_nowait(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                of=self.table1, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE OF mytable NOWAIT",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_basic(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                of=self.table1
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE OF mytable",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_skip_locked(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                of=self.table1, skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE OF mytable SKIP LOCKED",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_join_one(self):
+        self.assert_compile(
+            self.join.select(self.table2.c.mytable_id == 7).with_for_update(
+                of=[self.join]
+            ),
+            "SELECT table2.mytable_id, "
+            "mytable.myid, mytable.name, mytable.description "
+            "FROM table2 "
+            "INNER JOIN mytable ON table2.mytable_id = mytable.myid "
+            "WHERE table2.mytable_id = %s "
+            "FOR UPDATE OF mytable, table2",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_column_list_aliased(self):
+        ta = self.table1.alias()
+        self.assert_compile(
+            ta.select(ta.c.myid == 7).with_for_update(
+                of=[ta.c.myid, ta.c.name]
+            ),
+            "SELECT mytable_1.myid, mytable_1.name, mytable_1.description "
+            "FROM mytable AS mytable_1 "
+            "WHERE mytable_1.myid = %s FOR UPDATE OF mytable_1",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_join_aliased(self):
+        ta = self.table1.alias()
+        alias_join = self.table2.join(
+            ta, self.table2.c.mytable_id == ta.c.myid
+        )
+        self.assert_compile(
+            alias_join.select(self.table2.c.mytable_id == 7).with_for_update(
+                of=[alias_join]
+            ),
+            "SELECT table2.mytable_id, "
+            "mytable_1.myid, mytable_1.name, mytable_1.description "
+            "FROM table2 "
+            "INNER JOIN mytable AS mytable_1 "
+            "ON table2.mytable_id = mytable_1.myid "
+            "WHERE table2.mytable_id = %s "
+            "FOR UPDATE OF mytable_1, table2",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_read_nowait(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True, of=self.table1, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable NOWAIT",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_read_skip_locked(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True, of=self.table1, skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable SKIP LOCKED",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_read_nowait_column_list(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True,
+                of=[self.table1.c.myid, self.table1.c.name],
+                nowait=True,
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable NOWAIT",
+            dialect=self.for_update_of_dialect,
+        )
+
+    def test_for_update_of_read(self):
+        self.assert_compile(
+            self.table1.select(self.table1.c.myid == 7).with_for_update(
+                read=True, of=self.table1
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable",
+            dialect=self.for_update_of_dialect,
+        )
index aabbc3ac3b0aebbeac79b85e4b1201e1702c410f..57099f755c56d847fc516449ab92779ba15aed75 100644 (file)
@@ -946,6 +946,24 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT",
         )
 
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                key_share=True, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR NO KEY UPDATE NOWAIT",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                key_share=True, read=True, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR KEY SHARE NOWAIT",
+        )
+
         self.assert_compile(
             table1.select(table1.c.myid == 7).with_for_update(
                 read=True, skip_locked=True
@@ -973,6 +991,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "FOR SHARE OF mytable NOWAIT",
         )
 
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                key_share=True, read=True, nowait=True, of=table1
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR KEY SHARE OF mytable NOWAIT",
+        )
+
         self.assert_compile(
             table1.select(table1.c.myid == 7).with_for_update(
                 read=True, nowait=True, of=table1.c.myid
@@ -991,6 +1018,27 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "FOR SHARE OF mytable NOWAIT",
         )
 
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                read=True,
+                skip_locked=True,
+                of=[table1.c.myid, table1.c.name],
+                key_share=True,
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR KEY SHARE OF mytable SKIP LOCKED",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                skip_locked=True, of=[table1.c.myid, table1.c.name]
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR UPDATE OF mytable SKIP LOCKED",
+        )
+
         self.assert_compile(
             table1.select(table1.c.myid == 7).with_for_update(
                 read=True, skip_locked=True, of=[table1.c.myid, table1.c.name]
@@ -1054,6 +1102,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "FOR KEY SHARE OF mytable",
         )
 
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                read=True, of=table1
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR SHARE OF mytable",
+        )
+
         self.assert_compile(
             table1.select(table1.c.myid == 7).with_for_update(
                 read=True, key_share=True, skip_locked=True
@@ -1063,6 +1120,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "FOR KEY SHARE SKIP LOCKED",
         )
 
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                key_share=True, skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR NO KEY UPDATE SKIP LOCKED",
+        )
+
         ta = table1.alias()
         self.assert_compile(
             ta.select(ta.c.myid == 7).with_for_update(
index 73d8d0e5443f9a9a86398ce20a3c9fcb97c8ddf1..c739253f0215b616b9038a2ef4ee4533d028e39e 100644 (file)
@@ -1549,3 +1549,7 @@ class DefaultRequirements(SuiteRequirements):
     @property
     def computed_columns_reflect_persisted(self):
         return self.computed_columns + skip_if("oracle")
+
+    @property
+    def supports_for_update_of(self):
+        return only_if(lambda config: config.db.dialect.supports_for_update_of)