From: Mike Bayer Date: Tue, 29 Sep 2020 18:17:42 +0000 (-0400) Subject: Scan for tables without relying upon whereclause X-Git-Tag: rel_1_3_20~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f4cc32f66269c98b379920334180514df921397a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Scan for tables without relying upon whereclause Fixed bug where an UPDATE statement against a JOIN using MySQL multi-table format would fail to include the table prefix for the target table if the statement had no WHERE clause, as only the WHERE clause were scanned to detect a "multi table update" at that particular point. The target is now also scanned if it's a JOIN to get the leftmost table as the primary table and the additional entries as additional FROM entries. Fixes: #5617 Change-Id: I26d74afebe06e28af28acf960258f170a1627823 (cherry picked from commit 7d8c93d3d0eaf39ec7beb1dcb32ee0e11d3f77fa) --- diff --git a/doc/build/changelog/unreleased_13/5617.rst b/doc/build/changelog/unreleased_13/5617.rst new file mode 100644 index 0000000000..da20787ca8 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5617.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, mysql + :tickets: 5617 + + Fixed bug where an UPDATE statement against a JOIN using MySQL multi-table + format would fail to include the table prefix for the target table if the + statement had no WHERE clause, as only the WHERE clause were scanned to + detect a "multi table update" at that particular point. The target + is now also scanned if it's a JOIN to get the leftmost table as the + primary table and the additional entries as additional FROM entries. + diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 0362435d2f..b13c942771 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -9,7 +9,7 @@ Provide :class:`_expression.Insert`, :class:`_expression.Update` and :class:`_expression.Delete`. """ - +from . import util as sql_util from .base import _from_objects from .base import _generative from .base import DialectKWArgs @@ -817,7 +817,9 @@ class Update(ValuesBase): @property def _extra_froms(self): froms = [] - seen = {self.table} + + all_tables = list(sql_util.tables_from_leftmost(self.table)) + seen = {all_tables[0]} if self._whereclause is not None: for item in _from_objects(self._whereclause): @@ -825,6 +827,7 @@ class Update(ValuesBase): froms.append(item) seen.update(item._cloned_set) + froms.extend(all_tables[1:]) return froms diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 29ed6fa718..73a5b84b00 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -351,6 +351,19 @@ def clause_is_present(clause, search): return False +def tables_from_leftmost(clause): + if isinstance(clause, Join): + for t in tables_from_leftmost(clause.left): + yield t + for t in tables_from_leftmost(clause.right): + yield t + elif isinstance(clause, FromGrouping): + for t in tables_from_leftmost(clause.element): + yield t + else: + yield clause + + def surface_selectables(clause): stack = [clause] while stack: diff --git a/setup.cfg b/setup.cfg index ff68592b8f..f15ab5d7b6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -59,7 +59,7 @@ postgresql_psycopg2cffi=postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/tes mysql=mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 pymysql=mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 -mssql=mssql+pyodbc://scott:tiger@ms_2008 +mssql=mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server mssql_pymssql=mssql+pymssql://scott:tiger@ms_2008 oracle=oracle://scott:tiger@127.0.0.1:1521 diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py index 623efb9a04..910391d808 100644 --- a/test/aaa_profiling/test_compiler.py +++ b/test/aaa_profiling/test_compiler.py @@ -65,7 +65,7 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults): def test_update_whereclause(self): t1.update().where(t1.c.c2 == 12).compile(dialect=self.dialect) - @profiling.function_call_count() + @profiling.function_call_count(variance=0.20) def go(): t1.update().where(t1.c.c2 == 12).compile(dialect=self.dialect) diff --git a/test/profiles.txt b/test/profiles.txt index 55cf1300d9..5c08cd5c66 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -113,26 +113,38 @@ test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3. # TEST: test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mssql_pyodbc_dbapiunicode_cextensions 156,156 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mssql_pyodbc_dbapiunicode_nocextensions 156,156 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_mysqldb_dbapiunicode_cextensions 158,158,158,158 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_mysqldb_dbapiunicode_nocextensions 158,158,158,158 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_pymysql_dbapiunicode_cextensions 158,158 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_pymysql_dbapiunicode_nocextensions 158,158 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_oracle_cx_oracle_dbapiunicode_cextensions 158,156 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_oracle_cx_oracle_dbapiunicode_nocextensions 158,156 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_postgresql_psycopg2_dbapiunicode_cextensions 158,158,158 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_postgresql_psycopg2_dbapiunicode_nocextensions 158,158,158 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 157,156 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 157,156 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_mysqldb_dbapiunicode_cextensions 158,158,158,158 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_mysqldb_dbapiunicode_nocextensions 158,158,158,158 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_pymysql_dbapiunicode_cextensions 158,158 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_pymysql_dbapiunicode_nocextensions 158,158 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mssql_pyodbc_dbapiunicode_cextensions 162 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mssql_pyodbc_dbapiunicode_nocextensions 162 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_mysqldb_dbapiunicode_cextensions 162 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_mysqldb_dbapiunicode_nocextensions 162 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_pymysql_dbapiunicode_cextensions 162 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_mysql_pymysql_dbapiunicode_nocextensions 162 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_oracle_cx_oracle_dbapiunicode_cextensions 162 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_oracle_cx_oracle_dbapiunicode_nocextensions 162 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_postgresql_psycopg2_dbapiunicode_cextensions 162 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_postgresql_psycopg2_dbapiunicode_nocextensions 162 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 162 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 162 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_mssql_pyodbc_dbapiunicode_cextensions 167 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_mssql_pyodbc_dbapiunicode_nocextensions 167 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_mysql_mysqldb_dbapiunicode_cextensions 167 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_mysql_mysqldb_dbapiunicode_nocextensions 167 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_mysql_pymysql_dbapiunicode_cextensions 167 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_mysql_pymysql_dbapiunicode_nocextensions 167 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_oracle_cx_oracle_dbapiunicode_cextensions 167 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_oracle_cx_oracle_dbapiunicode_nocextensions 167 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_postgresql_psycopg2_dbapiunicode_cextensions 167 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_postgresql_psycopg2_dbapiunicode_nocextensions 167 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_sqlite_pysqlite_dbapiunicode_cextensions 167 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.7_sqlite_pysqlite_dbapiunicode_nocextensions 167 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_mysqldb_dbapiunicode_cextensions 158 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_mysqldb_dbapiunicode_nocextensions 158 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_pymysql_dbapiunicode_cextensions 158 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_mysql_pymysql_dbapiunicode_nocextensions 158 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_oracle_cx_oracle_dbapiunicode_cextensions 158 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_oracle_cx_oracle_dbapiunicode_nocextensions 158 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_postgresql_psycopg2_dbapiunicode_cextensions 158,158,158 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_postgresql_psycopg2_dbapiunicode_nocextensions 158,158,158 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_postgresql_psycopg2_dbapiunicode_cextensions 158 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_postgresql_psycopg2_dbapiunicode_nocextensions 158 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_sqlite_pysqlite_dbapiunicode_cextensions 162 test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.8_sqlite_pysqlite_dbapiunicode_nocextensions 162 diff --git a/test/sql/test_update.py b/test/sql/test_update.py index bb6b3df763..69f31940d8 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -791,7 +791,7 @@ class UpdateFromCompileTest( dialect="mysql", ) - def test_update_from_join_mysql(self): + def test_update_from_join_mysql_whereclause(self): users, addresses = self.tables.users, self.tables.addresses j = users.join(addresses) @@ -809,6 +809,73 @@ class UpdateFromCompileTest( dialect=mysql.dialect(), ) + def test_update_from_join_mysql_no_whereclause_one(self): + users, addresses = self.tables.users, self.tables.addresses + + j = users.join(addresses) + self.assert_compile( + update(j).values(name="newname"), + "" + "UPDATE users " + "INNER JOIN addresses ON users.id = addresses.user_id " + "SET users.name=%s", + checkparams={"name": "newname"}, + dialect=mysql.dialect(), + ) + + def test_update_from_join_mysql_no_whereclause_two(self): + users, addresses = self.tables.users, self.tables.addresses + + j = users.join(addresses) + self.assert_compile( + update(j).values({users.c.name: addresses.c.email_address}), + "" + "UPDATE users " + "INNER JOIN addresses ON users.id = addresses.user_id " + "SET users.name=addresses.email_address", + checkparams={}, + dialect=mysql.dialect(), + ) + + def test_update_from_join_mysql_no_whereclause_three(self): + users, addresses, dingalings = ( + self.tables.users, + self.tables.addresses, + self.tables.dingalings, + ) + + j = users.join(addresses).join(dingalings) + self.assert_compile( + update(j).values({users.c.name: dingalings.c.id}), + "" + "UPDATE users " + "INNER JOIN addresses ON users.id = addresses.user_id " + "INNER JOIN dingalings ON addresses.id = dingalings.address_id " + "SET users.name=dingalings.id", + checkparams={}, + dialect=mysql.dialect(), + ) + + def test_update_from_join_mysql_no_whereclause_four(self): + users, addresses, dingalings = ( + self.tables.users, + self.tables.addresses, + self.tables.dingalings, + ) + + j = users.join(addresses).join(dingalings) + + self.assert_compile( + update(j).values(name="foo"), + "" + "UPDATE users " + "INNER JOIN addresses ON users.id = addresses.user_id " + "INNER JOIN dingalings ON addresses.id = dingalings.address_id " + "SET users.name=%s", + checkparams={"name": "foo"}, + dialect=mysql.dialect(), + ) + def test_render_table(self): users, addresses = self.tables.users, self.tables.addresses