]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Scan for tables without relying upon whereclause
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Sep 2020 18:17:42 +0000 (14:17 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Sep 2020 20:46:56 +0000 (16:46 -0400)
Fixed bug where an UPDATE statement against a JOIN using MySQL multi-table
format would fail to include the table prefix for the target table if the
statement had no WHERE clause, as only the WHERE clause were scanned to
detect a "multi table update" at that particular point.  The target
is now also scanned if it's a JOIN to get the leftmost table as the
primary table and the additional entries as additional FROM entries.

Fixes: #5617
Change-Id: I26d74afebe06e28af28acf960258f170a1627823

doc/build/changelog/unreleased_13/5617.rst [new file with mode: 0644]
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/util.py
test/profiles.txt
test/sql/test_update.py

diff --git a/doc/build/changelog/unreleased_13/5617.rst b/doc/build/changelog/unreleased_13/5617.rst
new file mode 100644 (file)
index 0000000..da20787
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, mysql
+    :tickets: 5617
+
+    Fixed bug where an UPDATE statement against a JOIN using MySQL multi-table
+    format would fail to include the table prefix for the target table if the
+    statement had no WHERE clause, as only the WHERE clause were scanned to
+    detect a "multi table update" at that particular point.  The target
+    is now also scanned if it's a JOIN to get the leftmost table as the
+    primary table and the additional entries as additional FROM entries.
+
index 5ddc9ef82dbc135283de547d4d14bf89c2f13a91..c923bf651aee4818cfea01f04ac01c0dd5f67d31 100644 (file)
@@ -12,6 +12,7 @@ Provide :class:`_expression.Insert`, :class:`_expression.Update` and
 from sqlalchemy.types import NullType
 from . import coercions
 from . import roles
+from . import util as sql_util
 from .base import _entity_namespace_key
 from .base import _from_objects
 from .base import _generative
@@ -47,7 +48,9 @@ class DMLState(CompileState):
 
     def _make_extra_froms(self, statement):
         froms = []
-        seen = {statement.table}
+
+        all_tables = list(sql_util.tables_from_leftmost(statement.table))
+        seen = {all_tables[0]}
 
         for crit in statement._where_criteria:
             for item in _from_objects(crit):
@@ -55,6 +58,7 @@ class DMLState(CompileState):
                     froms.append(item)
                 seen.update(item._cloned_set)
 
+        froms.extend(all_tables[1:])
         return froms
 
     def _process_multi_values(self, statement):
index 96fa209fd7655bf8b5117e879703d9a899519b9b..e4f7532bed7195cb7157115f788f09ccc7311689 100644 (file)
@@ -366,6 +366,19 @@ def clause_is_present(clause, search):
         return False
 
 
+def tables_from_leftmost(clause):
+    if isinstance(clause, Join):
+        for t in tables_from_leftmost(clause.left):
+            yield t
+        for t in tables_from_leftmost(clause.right):
+            yield t
+    elif isinstance(clause, FromGrouping):
+        for t in tables_from_leftmost(clause.element):
+            yield t
+    else:
+        yield clause
+
+
 def surface_selectables(clause):
     stack = [clause]
     while stack:
index 15f27925686b2d4fca9e78efa16a2552ed411622..4a21de4273b82640955509368e9eb57af11c81f2 100644 (file)
@@ -153,38 +153,38 @@ test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.
 
 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause
 
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mariadb_mysqldb_dbapiunicode_cextensions 154
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mariadb_mysqldb_dbapiunicode_nocextensions 154
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mariadb_pymysql_dbapiunicode_cextensions 154
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mariadb_pymysql_dbapiunicode_nocextensions 154
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mssql_pyodbc_dbapiunicode_cextensions 154
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mssql_pyodbc_dbapiunicode_nocextensions 154
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mariadb_mysqldb_dbapiunicode_cextensions 159
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mariadb_mysqldb_dbapiunicode_nocextensions 159
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mariadb_pymysql_dbapiunicode_cextensions 159
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mariadb_pymysql_dbapiunicode_nocextensions 159
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mssql_pyodbc_dbapiunicode_cextensions 159
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mssql_pyodbc_dbapiunicode_nocextensions 159
 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_mysqldb_dbapiunicode_cextensions 152
 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_mysqldb_dbapiunicode_nocextensions 152
 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_pymysql_dbapiunicode_cextensions 152
 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_pymysql_dbapiunicode_nocextensions 152
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_oracle_cx_oracle_dbapiunicode_cextensions 150
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_oracle_cx_oracle_dbapiunicode_nocextensions 150
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_postgresql_psycopg2_dbapiunicode_cextensions 154
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_postgresql_psycopg2_dbapiunicode_nocextensions 154
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 154
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 154
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mariadb_mysqldb_dbapiunicode_cextensions 160
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mariadb_mysqldb_dbapiunicode_nocextensions 158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mariadb_pymysql_dbapiunicode_cextensions 160
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mariadb_pymysql_dbapiunicode_nocextensions 158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mssql_pyodbc_dbapiunicode_cextensions 160
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mssql_pyodbc_dbapiunicode_nocextensions 158
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_oracle_cx_oracle_dbapiunicode_cextensions 157
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_oracle_cx_oracle_dbapiunicode_nocextensions 159
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_postgresql_psycopg2_dbapiunicode_cextensions 159
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_postgresql_psycopg2_dbapiunicode_nocextensions 159
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 159
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 159
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mariadb_mysqldb_dbapiunicode_cextensions 165
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mariadb_mysqldb_dbapiunicode_nocextensions 165
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mariadb_pymysql_dbapiunicode_cextensions 165
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mariadb_pymysql_dbapiunicode_nocextensions 165
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mssql_pyodbc_dbapiunicode_cextensions 165
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mssql_pyodbc_dbapiunicode_nocextensions 165
 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_mysqldb_dbapiunicode_cextensions 158
 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_mysqldb_dbapiunicode_nocextensions 158
 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_pymysql_dbapiunicode_cextensions 158
 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_pymysql_dbapiunicode_nocextensions 158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_oracle_cx_oracle_dbapiunicode_cextensions 156
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_oracle_cx_oracle_dbapiunicode_nocextensions 156
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_postgresql_psycopg2_dbapiunicode_cextensions 160
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_postgresql_psycopg2_dbapiunicode_nocextensions 160
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_sqlite_pysqlite_dbapiunicode_cextensions 160
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_sqlite_pysqlite_dbapiunicode_nocextensions 160
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_oracle_cx_oracle_dbapiunicode_cextensions 163
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_oracle_cx_oracle_dbapiunicode_nocextensions 163
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_postgresql_psycopg2_dbapiunicode_cextensions 165
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_postgresql_psycopg2_dbapiunicode_nocextensions 165
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_sqlite_pysqlite_dbapiunicode_cextensions 165
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_sqlite_pysqlite_dbapiunicode_nocextensions 165
 
 # TEST: test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_cached
 
index 201e6c64fe146889f15817db42800ca45c47377f..ec96af207e7f1240f4600c888d5678c9464021d6 100644 (file)
@@ -1016,7 +1016,7 @@ class UpdateFromCompileTest(
             dialect="mysql",
         )
 
-    def test_update_from_join_mysql(self):
+    def test_update_from_join_mysql_whereclause(self):
         users, addresses = self.tables.users, self.tables.addresses
 
         j = users.join(addresses)
@@ -1034,6 +1034,73 @@ class UpdateFromCompileTest(
             dialect=mysql.dialect(),
         )
 
+    def test_update_from_join_mysql_no_whereclause_one(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        j = users.join(addresses)
+        self.assert_compile(
+            update(j).values(name="newname"),
+            ""
+            "UPDATE users "
+            "INNER JOIN addresses ON users.id = addresses.user_id "
+            "SET users.name=%s",
+            checkparams={"name": "newname"},
+            dialect=mysql.dialect(),
+        )
+
+    def test_update_from_join_mysql_no_whereclause_two(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        j = users.join(addresses)
+        self.assert_compile(
+            update(j).values({users.c.name: addresses.c.email_address}),
+            ""
+            "UPDATE users "
+            "INNER JOIN addresses ON users.id = addresses.user_id "
+            "SET users.name=addresses.email_address",
+            checkparams={},
+            dialect=mysql.dialect(),
+        )
+
+    def test_update_from_join_mysql_no_whereclause_three(self):
+        users, addresses, dingalings = (
+            self.tables.users,
+            self.tables.addresses,
+            self.tables.dingalings,
+        )
+
+        j = users.join(addresses).join(dingalings)
+        self.assert_compile(
+            update(j).values({users.c.name: dingalings.c.id}),
+            ""
+            "UPDATE users "
+            "INNER JOIN addresses ON users.id = addresses.user_id "
+            "INNER JOIN dingalings ON addresses.id = dingalings.address_id "
+            "SET users.name=dingalings.id",
+            checkparams={},
+            dialect=mysql.dialect(),
+        )
+
+    def test_update_from_join_mysql_no_whereclause_four(self):
+        users, addresses, dingalings = (
+            self.tables.users,
+            self.tables.addresses,
+            self.tables.dingalings,
+        )
+
+        j = users.join(addresses).join(dingalings)
+
+        self.assert_compile(
+            update(j).values(name="foo"),
+            ""
+            "UPDATE users "
+            "INNER JOIN addresses ON users.id = addresses.user_id "
+            "INNER JOIN dingalings ON addresses.id = dingalings.address_id "
+            "SET users.name=%s",
+            checkparams={"name": "foo"},
+            dialect=mysql.dialect(),
+        )
+
     def test_render_table(self):
         users, addresses = self.tables.users, self.tables.addresses