]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Scan for tables without relying upon whereclause
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Sep 2020 18:17:42 +0000 (14:17 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Sep 2020 23:59:55 +0000 (19:59 -0400)
Fixed bug where an UPDATE statement against a JOIN using MySQL multi-table
format would fail to include the table prefix for the target table if the
statement had no WHERE clause, as only the WHERE clause were scanned to
detect a "multi table update" at that particular point.  The target
is now also scanned if it's a JOIN to get the leftmost table as the
primary table and the additional entries as additional FROM entries.

Fixes: #5617
Change-Id: I26d74afebe06e28af28acf960258f170a1627823
(cherry picked from commit 7d8c93d3d0eaf39ec7beb1dcb32ee0e11d3f77fa)

doc/build/changelog/unreleased_13/5617.rst [new file with mode: 0644]
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/util.py
setup.cfg
test/aaa_profiling/test_compiler.py
test/profiles.txt
test/sql/test_update.py

diff --git a/doc/build/changelog/unreleased_13/5617.rst b/doc/build/changelog/unreleased_13/5617.rst
new file mode 100644 (file)
index 0000000..da20787
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, mysql
+    :tickets: 5617
+
+    Fixed bug where an UPDATE statement against a JOIN using MySQL multi-table
+    format would fail to include the table prefix for the target table if the
+    statement had no WHERE clause, as only the WHERE clause were scanned to
+    detect a "multi table update" at that particular point.  The target
+    is now also scanned if it's a JOIN to get the leftmost table as the
+    primary table and the additional entries as additional FROM entries.
+
index 0362435d2fce6be8afe67a0c4470d48333b56878..b13c942771057e89856414ccfd70b16d4e369909 100644 (file)
@@ -9,7 +9,7 @@ Provide :class:`_expression.Insert`, :class:`_expression.Update` and
 :class:`_expression.Delete`.
 
 """
-
+from . import util as sql_util
 from .base import _from_objects
 from .base import _generative
 from .base import DialectKWArgs
@@ -817,7 +817,9 @@ class Update(ValuesBase):
     @property
     def _extra_froms(self):
         froms = []
-        seen = {self.table}
+
+        all_tables = list(sql_util.tables_from_leftmost(self.table))
+        seen = {all_tables[0]}
 
         if self._whereclause is not None:
             for item in _from_objects(self._whereclause):
@@ -825,6 +827,7 @@ class Update(ValuesBase):
                     froms.append(item)
                 seen.update(item._cloned_set)
 
+        froms.extend(all_tables[1:])
         return froms
 
 
index 29ed6fa7185719af617a9f73f5adedc529dffd75..73a5b84b00fed147cec8512f14acb95d792bd013 100644 (file)
@@ -351,6 +351,19 @@ def clause_is_present(clause, search):
         return False
 
 
+def tables_from_leftmost(clause):
+    if isinstance(clause, Join):
+        for t in tables_from_leftmost(clause.left):
+            yield t
+        for t in tables_from_leftmost(clause.right):
+            yield t
+    elif isinstance(clause, FromGrouping):
+        for t in tables_from_leftmost(clause.element):
+            yield t
+    else:
+        yield clause
+
+
 def surface_selectables(clause):
     stack = [clause]
     while stack:
index ff68592b8f86e7b3e10de1b331cf0c714deac5d5..f15ab5d7b67c1bfe8cd5565732ec696bf8b2e526 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -59,7 +59,7 @@ postgresql_psycopg2cffi=postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/tes
 mysql=mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
 pymysql=mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
 
-mssql=mssql+pyodbc://scott:tiger@ms_2008
+mssql=mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server
 mssql_pymssql=mssql+pymssql://scott:tiger@ms_2008
 
 oracle=oracle://scott:tiger@127.0.0.1:1521
index 623efb9a04be0ed2d0f2cfd11746638cdb799de9..910391d80880d4c5a8065d777b36fc1c3149865f 100644 (file)
@@ -65,7 +65,7 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults):
     def test_update_whereclause(self):
         t1.update().where(t1.c.c2 == 12).compile(dialect=self.dialect)
 
-        @profiling.function_call_count()
+        @profiling.function_call_count(variance=0.20)
         def go():
             t1.update().where(t1.c.c2 == 12).compile(dialect=self.dialect)
 
index 55cf1300d9fa04b673122ecdd3f238a70cd624db..5c08cd5c66990065c0ad95219e0242890faa3f91 100644 (file)
@@ -113,26 +113,38 @@ test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.
 
 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause
 
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mssql_pyodbc_dbapiunicode_cextensions 156,156
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mssql_pyodbc_dbapiunicode_nocextensions 156,156
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_mysqldb_dbapiunicode_cextensions 158,158,158,158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_mysqldb_dbapiunicode_nocextensions 158,158,158,158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_pymysql_dbapiunicode_cextensions 158,158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_pymysql_dbapiunicode_nocextensions 158,158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_oracle_cx_oracle_dbapiunicode_cextensions 158,156
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_oracle_cx_oracle_dbapiunicode_nocextensions 158,156
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_postgresql_psycopg2_dbapiunicode_cextensions 158,158,158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_postgresql_psycopg2_dbapiunicode_nocextensions 158,158,158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 157,156
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 157,156
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_mysqldb_dbapiunicode_cextensions 158,158,158,158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_mysqldb_dbapiunicode_nocextensions 158,158,158,158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_pymysql_dbapiunicode_cextensions 158,158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_pymysql_dbapiunicode_nocextensions 158,158
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mssql_pyodbc_dbapiunicode_cextensions 162
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mssql_pyodbc_dbapiunicode_nocextensions 162
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_mysqldb_dbapiunicode_cextensions 162
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_mysqldb_dbapiunicode_nocextensions 162
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_pymysql_dbapiunicode_cextensions 162
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_pymysql_dbapiunicode_nocextensions 162
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_oracle_cx_oracle_dbapiunicode_cextensions 162
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_oracle_cx_oracle_dbapiunicode_nocextensions 162
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_postgresql_psycopg2_dbapiunicode_cextensions 162
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_postgresql_psycopg2_dbapiunicode_nocextensions 162
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 162
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 162
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_mssql_pyodbc_dbapiunicode_cextensions 167
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_mssql_pyodbc_dbapiunicode_nocextensions 167
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_mysql_mysqldb_dbapiunicode_cextensions 167
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_mysql_mysqldb_dbapiunicode_nocextensions 167
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_mysql_pymysql_dbapiunicode_cextensions 167
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_mysql_pymysql_dbapiunicode_nocextensions 167
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_oracle_cx_oracle_dbapiunicode_cextensions 167
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_oracle_cx_oracle_dbapiunicode_nocextensions 167
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_postgresql_psycopg2_dbapiunicode_cextensions 167
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_postgresql_psycopg2_dbapiunicode_nocextensions 167
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_sqlite_pysqlite_dbapiunicode_cextensions 167
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_sqlite_pysqlite_dbapiunicode_nocextensions 167
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_mysqldb_dbapiunicode_cextensions 158
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_mysqldb_dbapiunicode_nocextensions 158
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_pymysql_dbapiunicode_cextensions 158
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_pymysql_dbapiunicode_nocextensions 158
 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_oracle_cx_oracle_dbapiunicode_cextensions 158
 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_oracle_cx_oracle_dbapiunicode_nocextensions 158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_postgresql_psycopg2_dbapiunicode_cextensions 158,158,158
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_postgresql_psycopg2_dbapiunicode_nocextensions 158,158,158
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_postgresql_psycopg2_dbapiunicode_cextensions 158
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_postgresql_psycopg2_dbapiunicode_nocextensions 158
 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_sqlite_pysqlite_dbapiunicode_cextensions 162
 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_sqlite_pysqlite_dbapiunicode_nocextensions 162
 
index bb6b3df763948a2514b2533374cb21040e7c5699..69f31940d8ac08685502045a230fbce307be8b5d 100644 (file)
@@ -791,7 +791,7 @@ class UpdateFromCompileTest(
             dialect="mysql",
         )
 
-    def test_update_from_join_mysql(self):
+    def test_update_from_join_mysql_whereclause(self):
         users, addresses = self.tables.users, self.tables.addresses
 
         j = users.join(addresses)
@@ -809,6 +809,73 @@ class UpdateFromCompileTest(
             dialect=mysql.dialect(),
         )
 
+    def test_update_from_join_mysql_no_whereclause_one(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        j = users.join(addresses)
+        self.assert_compile(
+            update(j).values(name="newname"),
+            ""
+            "UPDATE users "
+            "INNER JOIN addresses ON users.id = addresses.user_id "
+            "SET users.name=%s",
+            checkparams={"name": "newname"},
+            dialect=mysql.dialect(),
+        )
+
+    def test_update_from_join_mysql_no_whereclause_two(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        j = users.join(addresses)
+        self.assert_compile(
+            update(j).values({users.c.name: addresses.c.email_address}),
+            ""
+            "UPDATE users "
+            "INNER JOIN addresses ON users.id = addresses.user_id "
+            "SET users.name=addresses.email_address",
+            checkparams={},
+            dialect=mysql.dialect(),
+        )
+
+    def test_update_from_join_mysql_no_whereclause_three(self):
+        users, addresses, dingalings = (
+            self.tables.users,
+            self.tables.addresses,
+            self.tables.dingalings,
+        )
+
+        j = users.join(addresses).join(dingalings)
+        self.assert_compile(
+            update(j).values({users.c.name: dingalings.c.id}),
+            ""
+            "UPDATE users "
+            "INNER JOIN addresses ON users.id = addresses.user_id "
+            "INNER JOIN dingalings ON addresses.id = dingalings.address_id "
+            "SET users.name=dingalings.id",
+            checkparams={},
+            dialect=mysql.dialect(),
+        )
+
+    def test_update_from_join_mysql_no_whereclause_four(self):
+        users, addresses, dingalings = (
+            self.tables.users,
+            self.tables.addresses,
+            self.tables.dingalings,
+        )
+
+        j = users.join(addresses).join(dingalings)
+
+        self.assert_compile(
+            update(j).values(name="foo"),
+            ""
+            "UPDATE users "
+            "INNER JOIN addresses ON users.id = addresses.user_id "
+            "INNER JOIN dingalings ON addresses.id = dingalings.address_id "
+            "SET users.name=%s",
+            checkparams={"name": "foo"},
+            dialect=mysql.dialect(),
+        )
+
     def test_render_table(self):
         users, addresses = self.tables.users, self.tables.addresses