from ...engine import reflection
from ...sql import compiler
from ...sql import elements
+from ...sql import util as sql_util
from ...types import BINARY
from ...types import BLOB
from ...types import BOOLEAN
def for_update_clause(self, select, **kw):
if select._for_update_arg.read:
- return " LOCK IN SHARE MODE"
+ tmp = " LOCK IN SHARE MODE"
else:
- return " FOR UPDATE"
+ tmp = " FOR UPDATE"
+
+ if select._for_update_arg.of and self.dialect.supports_for_update_of:
+
+ tables = util.OrderedSet()
+ for c in select._for_update_arg.of:
+ tables.update(sql_util.surface_selectables_only(c))
+
+ tmp += " OF " + ", ".join(
+ self.process(table, ashint=True, use_schema=False, **kw)
+ for table in tables
+ )
+
+ if select._for_update_arg.nowait:
+ tmp += " NOWAIT"
+
+ if select._for_update_arg.skip_locked and self.dialect._is_mysql:
+ tmp += " SKIP LOCKED"
+
+ return tmp
def limit_clause(self, select, **kw):
# MySQL supports:
supports_native_enum = True
+ supports_for_update_of = False # default for MySQL ...
+ # ... may be updated to True for MySQL 8+ in initialize()
+
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
supports_multivalues_insert = True
default.DefaultDialect.initialize(self, connection)
+ self.supports_for_update_of = (
+ self._is_mysql and self.server_version_info >= (8,)
+ )
+
self._needs_correct_for_88718_96365 = (
not self._is_mariadb and self.server_version_info >= (8,)
)
from sqlalchemy import Integer
from sqlalchemy import testing
from sqlalchemy import update
+from sqlalchemy.dialects.mysql import base as mysql
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
+from sqlalchemy.sql import column
+from sqlalchemy.sql import table
+from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import fixtures
# no subquery, should be locked
self._assert_a_is_locked(True)
self._assert_b_is_locked(True)
+
+
+class MySQLForUpdateCompileTest(fixtures.TestBase, AssertsCompiledSQL):
+ __dialect__ = mysql.dialect()
+
+ table1 = table(
+ "mytable", column("myid"), column("name"), column("description")
+ )
+ table2 = table("table2", column("mytable_id"))
+ join = table2.join(table1, table2.c.mytable_id == table1.c.myid)
+ for_update_of_dialect = mysql.dialect()
+ for_update_of_dialect.server_version_info = (8, 0, 0)
+ for_update_of_dialect.supports_for_update_of = True
+
+ def test_for_update_basic(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s FOR UPDATE",
+ )
+
+ def test_for_update_read(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE",
+ )
+
+ def test_for_update_skip_locked(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ skip_locked=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "FOR UPDATE SKIP LOCKED",
+ )
+
+ def test_for_update_read_and_skip_locked(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True, skip_locked=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "LOCK IN SHARE MODE SKIP LOCKED",
+ )
+
+ def test_for_update_nowait(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ nowait=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "FOR UPDATE NOWAIT",
+ )
+
+ def test_for_update_read_and_nowait(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True, nowait=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "LOCK IN SHARE MODE NOWAIT",
+ )
+
+ def test_for_update_of_nowait(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ of=self.table1, nowait=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "FOR UPDATE OF mytable NOWAIT",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_basic(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ of=self.table1
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "FOR UPDATE OF mytable",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_skip_locked(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ of=self.table1, skip_locked=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "FOR UPDATE OF mytable SKIP LOCKED",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_join_one(self):
+ self.assert_compile(
+ self.join.select(self.table2.c.mytable_id == 7).with_for_update(
+ of=[self.join]
+ ),
+ "SELECT table2.mytable_id, "
+ "mytable.myid, mytable.name, mytable.description "
+ "FROM table2 "
+ "INNER JOIN mytable ON table2.mytable_id = mytable.myid "
+ "WHERE table2.mytable_id = %s "
+ "FOR UPDATE OF mytable, table2",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_column_list_aliased(self):
+ ta = self.table1.alias()
+ self.assert_compile(
+ ta.select(ta.c.myid == 7).with_for_update(
+ of=[ta.c.myid, ta.c.name]
+ ),
+ "SELECT mytable_1.myid, mytable_1.name, mytable_1.description "
+ "FROM mytable AS mytable_1 "
+ "WHERE mytable_1.myid = %s FOR UPDATE OF mytable_1",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_join_aliased(self):
+ ta = self.table1.alias()
+ alias_join = self.table2.join(
+ ta, self.table2.c.mytable_id == ta.c.myid
+ )
+ self.assert_compile(
+ alias_join.select(self.table2.c.mytable_id == 7).with_for_update(
+ of=[alias_join]
+ ),
+ "SELECT table2.mytable_id, "
+ "mytable_1.myid, mytable_1.name, mytable_1.description "
+ "FROM table2 "
+ "INNER JOIN mytable AS mytable_1 "
+ "ON table2.mytable_id = mytable_1.myid "
+ "WHERE table2.mytable_id = %s "
+ "FOR UPDATE OF mytable_1, table2",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_read_nowait(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True, of=self.table1, nowait=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "LOCK IN SHARE MODE OF mytable NOWAIT",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_read_skip_locked(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True, of=self.table1, skip_locked=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "LOCK IN SHARE MODE OF mytable SKIP LOCKED",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_read_nowait_column_list(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True,
+ of=[self.table1.c.myid, self.table1.c.name],
+ nowait=True,
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "LOCK IN SHARE MODE OF mytable NOWAIT",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_read(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True, of=self.table1
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "LOCK IN SHARE MODE OF mytable",
+ dialect=self.for_update_of_dialect,
+ )
"FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT",
)
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
+ key_share=True, nowait=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR NO KEY UPDATE NOWAIT",
+ )
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
+ key_share=True, read=True, nowait=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR KEY SHARE NOWAIT",
+ )
+
self.assert_compile(
table1.select(table1.c.myid == 7).with_for_update(
read=True, skip_locked=True
"FOR SHARE OF mytable NOWAIT",
)
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
+ key_share=True, read=True, nowait=True, of=table1
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR KEY SHARE OF mytable NOWAIT",
+ )
+
self.assert_compile(
table1.select(table1.c.myid == 7).with_for_update(
read=True, nowait=True, of=table1.c.myid
"FOR SHARE OF mytable NOWAIT",
)
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
+ read=True,
+ skip_locked=True,
+ of=[table1.c.myid, table1.c.name],
+ key_share=True,
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR KEY SHARE OF mytable SKIP LOCKED",
+ )
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
+ skip_locked=True, of=[table1.c.myid, table1.c.name]
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR UPDATE OF mytable SKIP LOCKED",
+ )
+
self.assert_compile(
table1.select(table1.c.myid == 7).with_for_update(
read=True, skip_locked=True, of=[table1.c.myid, table1.c.name]
"FOR KEY SHARE OF mytable",
)
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
+ read=True, of=table1
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR SHARE OF mytable",
+ )
+
self.assert_compile(
table1.select(table1.c.myid == 7).with_for_update(
read=True, key_share=True, skip_locked=True
"FOR KEY SHARE SKIP LOCKED",
)
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
+ key_share=True, skip_locked=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR NO KEY UPDATE SKIP LOCKED",
+ )
+
ta = table1.alias()
self.assert_compile(
ta.select(ta.c.myid == 7).with_for_update(