From: Mike Bayer Date: Fri, 28 Apr 2023 16:07:09 +0000 (-0400) Subject: implement FromLinter for UPDATE, DELETE statements X-Git-Tag: rel_2_0_13~8^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=4a62625d99470c8928422c4822df5234b93b6bb8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement FromLinter for UPDATE, DELETE statements Implemented the "cartesian product warning" for UPDATE and DELETE statements, those which include multiple tables that are not correlated together in some way. Fixed issue where :func:`_dml.update` construct that included multiple tables and no VALUES clause would raise with an internal error. Current behavior for :class:`_dml.Update` with no values is to generate a SQL UPDATE statement with an empty "set" clause, so this has been made consistent for this specific sub-case. Fixes: #9721 Change-Id: I556639811cc930d2e37532965d2ae751882af921 --- diff --git a/doc/build/changelog/unreleased_20/9721.rst b/doc/build/changelog/unreleased_20/9721.rst new file mode 100644 index 0000000000..2a2b29f84a --- /dev/null +++ b/doc/build/changelog/unreleased_20/9721.rst @@ -0,0 +1,16 @@ +.. change:: + :tags: usecase, sql + :tickets: 9721 + + Implemented the "cartesian product warning" for UPDATE and DELETE + statements, those which include multiple tables that are not correlated + together in some way. + +.. change:: + :tags: bug, sql + + Fixed issue where :func:`_dml.update` construct that included multiple + tables and no VALUES clause would raise with an internal error. Current + behavior for :class:`_dml.Update` with no values is to generate a SQL + UPDATE statement with an empty "set" clause, so this has been made + consistent for this specific sub-case. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index b33ce4aec8..aa319e2393 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2425,13 +2425,13 @@ class MSSQLCompiler(compiler.SQLCompiler): for t in [from_table] + extra_froms ) - def delete_table_clause(self, delete_stmt, from_table, extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: ashint = True return from_table._compiler_dispatch( - self, asfrom=True, iscrud=True, ashint=ashint + self, asfrom=True, iscrud=True, ashint=ashint, **kw ) def delete_extra_from_clause( diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 2ed2bbc7a0..ae40fea99c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1657,13 +1657,13 @@ class MySQLCompiler(compiler.SQLCompiler): ): return None - def delete_table_clause(self, delete_stmt, from_table, extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: ashint = True return from_table._compiler_dispatch( - self, asfrom=True, iscrud=True, ashint=ashint + self, asfrom=True, iscrud=True, ashint=ashint, **kw ) def delete_extra_from_clause( diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 554a841120..619ff08488 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -710,7 +710,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): else: return None, None - def warn(self): + def warn(self, stmt_type="SELECT"): the_rest, start_with = self.lint() # FROMS left over? boom @@ -719,7 +719,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): froms = the_rest if froms: template = ( - "SELECT statement has a cartesian product between " + "{stmt_type} statement has a cartesian product between " "FROM element(s) {froms} and " 'FROM element "{start}". Apply join condition(s) ' "between each element to resolve." @@ -728,7 +728,9 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): f'"{self.froms[from_]}"' for from_ in froms ) message = template.format( - froms=froms_str, start=self.froms[start_with] + stmt_type=stmt_type, + froms=froms_str, + start=self.froms[start_with], ) util.warn(message) @@ -5997,6 +5999,7 @@ class SQLCompiler(Compiled): ) def visit_update(self, update_stmt, **kw): + compile_state = update_stmt._compile_state_factory( update_stmt, self, **kw ) @@ -6010,6 +6013,15 @@ class SQLCompiler(Compiled): if not self.compile_state: self.compile_state = compile_state + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + extra_froms = compile_state._extra_froms is_multitable = bool(extra_froms) @@ -6040,7 +6052,11 @@ class SQLCompiler(Compiled): ) table_text = self.update_tables_clause( - update_stmt, update_stmt.table, render_extra_froms, **kw + update_stmt, + update_stmt.table, + render_extra_froms, + from_linter=from_linter, + **kw, ) crud_params_struct = crud._get_crud_params( self, update_stmt, compile_state, toplevel, **kw @@ -6081,6 +6097,7 @@ class SQLCompiler(Compiled): update_stmt.table, render_extra_froms, dialect_hints, + from_linter=from_linter, **kw, ) if extra_from_text: @@ -6088,7 +6105,7 @@ class SQLCompiler(Compiled): if update_stmt._where_criteria: t = self._generate_delimited_and_list( - update_stmt._where_criteria, **kw + update_stmt._where_criteria, from_linter=from_linter, **kw ) if t: text += " WHERE " + t @@ -6110,6 +6127,10 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = self._render_cte_clause(nesting_level=nesting_level) + text + if warn_linting: + assert from_linter is not None + from_linter.warn(stmt_type="UPDATE") + self.stack.pop(-1) return text @@ -6130,8 +6151,10 @@ class SQLCompiler(Compiled): "criteria within DELETE" ) - def delete_table_clause(self, delete_stmt, from_table, extra_froms): - return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) + def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): + return from_table._compiler_dispatch( + self, asfrom=True, iscrud=True, **kw + ) def visit_delete(self, delete_stmt, **kw): compile_state = delete_stmt._compile_state_factory( @@ -6147,6 +6170,15 @@ class SQLCompiler(Compiled): if not self.compile_state: self.compile_state = compile_state + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + extra_froms = compile_state._extra_froms correlate_froms = {delete_stmt.table}.union(extra_froms) @@ -6166,9 +6198,22 @@ class SQLCompiler(Compiled): ) text += "FROM " - table_text = self.delete_table_clause( - delete_stmt, delete_stmt.table, extra_froms - ) + + try: + table_text = self.delete_table_clause( + delete_stmt, + delete_stmt.table, + extra_froms, + from_linter=from_linter, + ) + except TypeError: + # anticipate 3rd party dialects that don't include **kw + # TODO: remove in 2.1 + table_text = self.delete_table_clause( + delete_stmt, delete_stmt.table, extra_froms + ) + if from_linter: + _ = self.process(delete_stmt.table, from_linter=from_linter) crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw) @@ -6199,6 +6244,7 @@ class SQLCompiler(Compiled): delete_stmt.table, extra_froms, dialect_hints, + from_linter=from_linter, **kw, ) if extra_from_text: @@ -6206,7 +6252,7 @@ class SQLCompiler(Compiled): if delete_stmt._where_criteria: t = self._generate_delimited_and_list( - delete_stmt._where_criteria, **kw + delete_stmt._where_criteria, from_linter=from_linter, **kw ) if t: text += " WHERE " + t @@ -6224,6 +6270,10 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = self._render_cte_clause(nesting_level=nesting_level) + text + if warn_linting: + assert from_linter is not None + from_linter.warn(stmt_type="DELETE") + self.stack.pop(-1) return text diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 563f61c046..16d5ce4941 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -1344,7 +1344,7 @@ def _get_update_multitable_params( ): normalized_params = { coercions.expect(roles.DMLColumnRole, c): param - for c, param in stmt_parameter_tuples + for c, param in stmt_parameter_tuples or () } include_table = compile_state.include_table_with_column_exprs diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py index 49370b1e67..9a471d5712 100644 --- a/test/sql/test_from_linter.py +++ b/test/sql/test_from_linter.py @@ -1,4 +1,5 @@ from sqlalchemy import column +from sqlalchemy import delete from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import JSON @@ -7,6 +8,7 @@ from sqlalchemy import sql from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import true +from sqlalchemy import update from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import expect_warnings @@ -382,18 +384,54 @@ class TestFindUnmatchingFroms(fixtures.TablesTest): froms, start = find_unmatching_froms(query) assert not froms + @testing.variation("dml", ["update", "delete"]) + @testing.combinations( + (False, False), (True, False), (True, True), argnames="twotable,error" + ) + def test_dml(self, dml, twotable, error): + if dml.update: + stmt = update(self.a) + elif dml.delete: + stmt = delete(self.a) + else: + dml.fail() + + stmt = stmt.where(self.a.c.col_a == "a1") + if twotable: + stmt = stmt.where(self.b.c.col_b == "a1") + + if not error: + stmt = stmt.where(self.b.c.col_b == self.a.c.col_a) + + froms, _ = find_unmatching_froms(stmt) + if error: + assert froms + else: + assert not froms + + +class TestLinterRoundTrip(fixtures.TablesTest): + __backend__ = True -class TestLinter(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("table_a", metadata, Column("col_a", Integer, primary_key=True)) - Table("table_b", metadata, Column("col_b", Integer, primary_key=True)) + Table( + "table_a", + metadata, + Column("col_a", Integer, primary_key=True, autoincrement=False), + ) + Table( + "table_b", + metadata, + Column("col_b", Integer, primary_key=True, autoincrement=False), + ) @classmethod def setup_bind(cls): # from linting is enabled by default return config.db + @testing.only_on("sqlite") def test_noop_for_unhandled_objects(self): with self.bind.connect() as conn: conn.exec_driver_sql("SELECT 1;").fetchone() @@ -429,6 +467,7 @@ class TestLinter(fixtures.TablesTest): with self.bind.connect() as conn: conn.execute(query) + @testing.requires.ctes def test_warn_anon_cte(self): a, b = self.tables("table_a", "table_b") @@ -444,6 +483,47 @@ class TestLinter(fixtures.TablesTest): with self.bind.connect() as conn: conn.execute(query) + @testing.variation( + "dml", + [ + ("update", testing.requires.update_from), + ("delete", testing.requires.delete_using), + ], + ) + @testing.combinations( + (False, False), (True, False), (True, True), argnames="twotable,error" + ) + def test_warn_dml(self, dml, twotable, error): + a, b = self.tables("table_a", "table_b") + + if dml.update: + stmt = update(a).values(col_a=5) + elif dml.delete: + stmt = delete(a) + else: + dml.fail() + + stmt = stmt.where(a.c.col_a == 1) + if twotable: + stmt = stmt.where(b.c.col_b == 1) + + if not error: + stmt = stmt.where(b.c.col_b == a.c.col_a) + + stmt_type = "UPDATE" if dml.update else "DELETE" + + with self.bind.connect() as conn: + if error: + with expect_warnings( + rf"{stmt_type} statement has a cartesian product between " + rf'FROM element\(s\) "table_[ab]" and FROM ' + rf'element "table_[ab]"' + ): + with self.bind.connect() as conn: + conn.execute(stmt) + else: + conn.execute(stmt) + def test_no_linting(self, metadata, connection): eng = engines.testing_engine( options={"enable_from_linting": False, "use_reaper": False} diff --git a/test/sql/test_update.py b/test/sql/test_update.py index ef8f117bcd..d8de5c277b 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -113,6 +113,59 @@ class _UpdateFromTestBase: class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): __dialect__ = "default_enhanced" + @testing.variation("twotable", [True, False]) + @testing.variation("values", ["none", "blank"]) + def test_update_no_params(self, values, twotable): + """test issue identified while doing #9721 + + UPDATE with empty VALUES but multiple tables would raise a + NoneType error; fixed this to emit an empty "SET" the way a single + table UPDATE currently does. + + both cases should probably raise CompileError, however this could + be backwards incompatible with current use cases (such as other test + suites) + + """ + + table1 = self.tables.mytable + table2 = self.tables.myothertable + + stmt = table1.update().where(table1.c.name == "jill") + if twotable: + stmt = stmt.where(table2.c.otherid == table1.c.myid) + + if values.blank: + stmt = stmt.values() + + if twotable: + if values.blank: + self.assert_compile( + stmt, + "UPDATE mytable SET FROM myothertable " + "WHERE mytable.name = :name_1 " + "AND myothertable.otherid = mytable.myid", + ) + elif values.none: + self.assert_compile( + stmt, + "UPDATE mytable SET myid=:myid, name=:name, " + "description=:description FROM myothertable " + "WHERE mytable.name = :name_1 " + "AND myothertable.otherid = mytable.myid", + ) + elif values.blank: + self.assert_compile( + stmt, + "UPDATE mytable SET WHERE mytable.name = :name_1", + ) + elif values.none: + self.assert_compile( + stmt, + "UPDATE mytable SET myid=:myid, name=:name, " + "description=:description WHERE mytable.name = :name_1", + ) + def test_update_literal_binds(self): table1 = self.tables.mytable