]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixes: #4860 Add SKIP LOCKED, OF, NOWAIT for mysql
authorRobotScribe <quentinso@theodo.fr>
Wed, 29 Apr 2020 10:34:57 +0000 (12:34 +0200)
committerRobotScribe <quentinso@theodo.fr>
Wed, 29 Apr 2020 12:45:54 +0000 (14:45 +0200)
lib/sqlalchemy/dialects/mysql/base.py
test/dialect/mysql/test_compiler.py

index 38f3fa6111aa138079d0ed7e0d03129d7ce72a09..a814de1b6d290aa758202dbc6c468b6d493b88e2 100644 (file)
@@ -808,6 +808,7 @@ from ...sql import coercions
 from ...sql import compiler
 from ...sql import elements
 from ...sql import roles
+from ...sql import util as sql_util
 from ...types import BINARY
 from ...types import BLOB
 from ...types import BOOLEAN
@@ -1494,9 +1495,27 @@ class MySQLCompiler(compiler.SQLCompiler):
 
     def for_update_clause(self, select, **kw):
         if select._for_update_arg.read:
-            return " LOCK IN SHARE MODE"
+            tmp = " LOCK IN SHARE MODE"
         else:
-            return " FOR UPDATE"
+            tmp = " FOR UPDATE"
+
+        if select._for_update_arg.of:
+
+            tables = util.OrderedSet()
+            for c in select._for_update_arg.of:
+                tables.update(sql_util.surface_selectables_only(c))
+
+            tmp += " OF " + ", ".join(
+                self.process(table, ashint=True, use_schema=False, **kw)
+                for table in tables
+            )
+
+        if select._for_update_arg.nowait:
+            tmp += " NOWAIT"
+        if select._for_update_arg.skip_locked:
+            tmp += " SKIP LOCKED"
+
+        return tmp
 
     def limit_clause(self, select, **kw):
         # MySQL supports:
index 4e6199c6f2bfb075c9de2239081bf3d12722a070..68dbd61329bcf7e36b12f08a31a0db8bed17f8d5 100644 (file)
@@ -370,6 +370,148 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE",
         )
 
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                of=table1
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE OF mytable",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                read=True, of=table1
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE SKIP LOCKED",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                read=True, skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE SKIP LOCKED",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                of=table1, skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE OF mytable SKIP LOCKED",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                read=True, of=table1, skip_locked=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable SKIP LOCKED",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE NOWAIT",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                read=True, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE NOWAIT",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                of=table1, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "FOR UPDATE OF mytable NOWAIT",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                read=True, of=table1, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable NOWAIT",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                read=True, of=[table1.c.myid, table1.c.name], nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable NOWAIT",
+        )
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(
+                read=True, of=table1, nowait=True
+            ),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s "
+            "LOCK IN SHARE MODE OF mytable NOWAIT",
+        )
+
+        ta = table1.alias()
+        self.assert_compile(
+            ta.select(ta.c.myid == 7).with_for_update(
+                of=[ta.c.myid, ta.c.name]
+            ),
+            "SELECT mytable_1.myid, mytable_1.name, mytable_1.description "
+            "FROM mytable AS mytable_1 "
+            "WHERE mytable_1.myid = %s FOR UPDATE OF mytable_1",
+        )
+
+        table2 = table("table2", column("mytable_id"))
+        join = table2.join(table1, table2.c.mytable_id == table1.c.myid)
+        self.assert_compile(
+            join.select(table2.c.mytable_id == 7).with_for_update(of=[join]),
+            "SELECT table2.mytable_id, "
+            "mytable.myid, mytable.name, mytable.description "
+            "FROM table2 "
+            "INNER JOIN mytable ON table2.mytable_id = mytable.myid "
+            "WHERE table2.mytable_id = %s "
+            "FOR UPDATE OF mytable, table2",
+        )
+
+        join = table2.join(ta, table2.c.mytable_id == ta.c.myid)
+        self.assert_compile(
+            join.select(table2.c.mytable_id == 7).with_for_update(of=[join]),
+            "SELECT table2.mytable_id, "
+            "mytable_1.myid, mytable_1.name, mytable_1.description "
+            "FROM table2 "
+            "INNER JOIN mytable AS mytable_1 "
+            "ON table2.mytable_id = mytable_1.myid "
+            "WHERE table2.mytable_id = %s "
+            "FOR UPDATE OF mytable_1, table2",
+        )
+
     def test_delete_extra_froms(self):
         t1 = table("t1", column("c1"))
         t2 = table("t2", column("c1"))