From 2899538c565509e7a05d6099737534d8a17cb23d Mon Sep 17 00:00:00 2001 From: RobotScribe Date: Wed, 29 Apr 2020 12:34:57 +0200 Subject: [PATCH] Fixes: #4860 Add SKIP LOCKED, OF, NOWAIT for mysql --- lib/sqlalchemy/dialects/mysql/base.py | 23 ++++- test/dialect/mysql/test_compiler.py | 142 ++++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 2 deletions(-) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 38f3fa6111..a814de1b6d 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -808,6 +808,7 @@ from ...sql import coercions from ...sql import compiler from ...sql import elements from ...sql import roles +from ...sql import util as sql_util from ...types import BINARY from ...types import BLOB from ...types import BOOLEAN @@ -1494,9 +1495,27 @@ class MySQLCompiler(compiler.SQLCompiler): def for_update_clause(self, select, **kw): if select._for_update_arg.read: - return " LOCK IN SHARE MODE" + tmp = " LOCK IN SHARE MODE" else: - return " FOR UPDATE" + tmp = " FOR UPDATE" + + if select._for_update_arg.of: + + tables = util.OrderedSet() + for c in select._for_update_arg.of: + tables.update(sql_util.surface_selectables_only(c)) + + tmp += " OF " + ", ".join( + self.process(table, ashint=True, use_schema=False, **kw) + for table in tables + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + if select._for_update_arg.skip_locked: + tmp += " SKIP LOCKED" + + return tmp def limit_clause(self, select, **kw): # MySQL supports: diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 4e6199c6f2..68dbd61329 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -370,6 +370,148 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE", ) + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update( + of=table1 + ), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s " + "FOR UPDATE OF mytable", + ) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update( + read=True, of=table1 + ), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s " + "LOCK IN SHARE MODE OF mytable", + ) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update( + skip_locked=True + ), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s " + "FOR UPDATE SKIP LOCKED", + ) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update( + read=True, skip_locked=True + ), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s " + "LOCK IN SHARE MODE SKIP LOCKED", + ) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update( + of=table1, skip_locked=True + ), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s " + "FOR UPDATE OF mytable SKIP LOCKED", + ) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update( + read=True, of=table1, skip_locked=True + ), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s " + "LOCK IN SHARE MODE OF mytable SKIP LOCKED", + ) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update( + nowait=True + ), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s " + "FOR UPDATE NOWAIT", + ) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update( + read=True, nowait=True + ), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s " + "LOCK IN SHARE MODE NOWAIT", + ) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update( + of=table1, nowait=True + ), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s " + "FOR UPDATE OF mytable NOWAIT", + ) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update( + read=True, of=table1, nowait=True + ), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s " + "LOCK IN SHARE MODE OF mytable NOWAIT", + ) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update( + read=True, of=[table1.c.myid, table1.c.name], nowait=True + ), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s " + "LOCK IN SHARE MODE OF mytable NOWAIT", + ) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update( + read=True, of=table1, nowait=True + ), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s " + "LOCK IN SHARE MODE OF mytable NOWAIT", + ) + + ta = table1.alias() + self.assert_compile( + ta.select(ta.c.myid == 7).with_for_update( + of=[ta.c.myid, ta.c.name] + ), + "SELECT mytable_1.myid, mytable_1.name, mytable_1.description " + "FROM mytable AS mytable_1 " + "WHERE mytable_1.myid = %s FOR UPDATE OF mytable_1", + ) + + table2 = table("table2", column("mytable_id")) + join = table2.join(table1, table2.c.mytable_id == table1.c.myid) + self.assert_compile( + join.select(table2.c.mytable_id == 7).with_for_update(of=[join]), + "SELECT table2.mytable_id, " + "mytable.myid, mytable.name, mytable.description " + "FROM table2 " + "INNER JOIN mytable ON table2.mytable_id = mytable.myid " + "WHERE table2.mytable_id = %s " + "FOR UPDATE OF mytable, table2", + ) + + join = table2.join(ta, table2.c.mytable_id == ta.c.myid) + self.assert_compile( + join.select(table2.c.mytable_id == 7).with_for_update(of=[join]), + "SELECT table2.mytable_id, " + "mytable_1.myid, mytable_1.name, mytable_1.description " + "FROM table2 " + "INNER JOIN mytable AS mytable_1 " + "ON table2.mytable_id = mytable_1.myid " + "WHERE table2.mytable_id = %s " + "FOR UPDATE OF mytable_1, table2", + ) + def test_delete_extra_froms(self): t1 = table("t1", column("c1")) t2 = table("t2", column("c1")) -- 2.47.3