git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement FromLinter for UPDATE, DELETE statements
author Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 28 Apr 2023 16:07:09 +0000 (12:07 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 9 May 2023 14:08:52 +0000 (10:08 -0400)
Implemented the "cartesian product warning" for UPDATE and DELETE
statements, those which include multiple tables that are not correlated
together in some way.

Fixed issue where :func:`_dml.update` construct that included multiple
tables and no VALUES clause would raise with an internal error. Current
behavior for :class:`_dml.Update` with no values is to generate a SQL
UPDATE statement with an empty "set" clause, so this has been made
consistent for this specific sub-case.

Fixes: #9721
Change-Id: I556639811cc930d2e37532965d2ae751882af921

doc/build/changelog/unreleased_20/9721.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
test/sql/test_from_linter.py
test/sql/test_update.py

diff --git a/doc/build/changelog/unreleased_20/9721.rst b/doc/build/changelog/unreleased_20/9721.rst
new file mode 100644 (file)
index 0000000..2a2b29f
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: usecase, sql
+    :tickets: 9721
+
+    Implemented the "cartesian product warning" for UPDATE and DELETE
+    statements, those which include multiple tables that are not correlated
+    together in some way.
+
+.. change::
+    :tags: bug, sql
+
+    Fixed issue where :func:`_dml.update` construct that included multiple
+    tables and no VALUES clause would raise with an internal error. Current
+    behavior for :class:`_dml.Update` with no values is to generate a SQL
+    UPDATE statement with an empty "set" clause, so this has been made 
+    consistent for this specific sub-case.
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index b33ce4aec8e6618fa9e12ed6af71d5991a14b80a..aa319e2393cd710e573ea944c4e5b417f5c7d989 100644 (file)
@@ -2425,13 +2425,13 @@ class MSSQLCompiler(compiler.SQLCompiler):
             for t in [from_table] + extra_froms
         )
 
-    def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+    def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
         """If we have extra froms make sure we render any alias as hint."""
         ashint = False
         if extra_froms:
             ashint = True
         return from_table._compiler_dispatch(
-            self, asfrom=True, iscrud=True, ashint=ashint
+            self, asfrom=True, iscrud=True, ashint=ashint, **kw
         )
 
     def delete_extra_from_clause(
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 2ed2bbc7a0e36c6d23e9ef166f169a08fb1d6091..ae40fea99cdfb8df009f2911105b509f650f7c94 100644 (file)
@@ -1657,13 +1657,13 @@ class MySQLCompiler(compiler.SQLCompiler):
     ):
         return None
 
-    def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+    def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
         """If we have extra froms make sure we render any alias as hint."""
         ashint = False
         if extra_froms:
             ashint = True
         return from_table._compiler_dispatch(
-            self, asfrom=True, iscrud=True, ashint=ashint
+            self, asfrom=True, iscrud=True, ashint=ashint, **kw
         )
 
     def delete_extra_from_clause(
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 554a84112060884ba21bb73121d5fdd8437a33df..619ff08488a7e97cd209912f31c300a60e56df76 100644 (file)
@@ -710,7 +710,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
         else:
             return None, None
 
-    def warn(self):
+    def warn(self, stmt_type="SELECT"):
         the_rest, start_with = self.lint()
 
         # FROMS left over?  boom
@@ -719,7 +719,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
             froms = the_rest
             if froms:
                 template = (
-                    "SELECT statement has a cartesian product between "
+                    "{stmt_type} statement has a cartesian product between "
                     "FROM element(s) {froms} and "
                     'FROM element "{start}".  Apply join condition(s) '
                     "between each element to resolve."
@@ -728,7 +728,9 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
                     f'"{self.froms[from_]}"' for from_ in froms
                 )
                 message = template.format(
-                    froms=froms_str, start=self.froms[start_with]
+                    stmt_type=stmt_type,
+                    froms=froms_str,
+                    start=self.froms[start_with],
                 )
 
                 util.warn(message)
@@ -5997,6 +5999,7 @@ class SQLCompiler(Compiled):
         )
 
     def visit_update(self, update_stmt, **kw):
+
         compile_state = update_stmt._compile_state_factory(
             update_stmt, self, **kw
         )
@@ -6010,6 +6013,15 @@ class SQLCompiler(Compiled):
             if not self.compile_state:
                 self.compile_state = compile_state
 
+        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+            from_linter = FromLinter({}, set())
+            warn_linting = self.linting & WARN_LINTING
+            if toplevel:
+                self.from_linter = from_linter
+        else:
+            from_linter = None
+            warn_linting = False
+
         extra_froms = compile_state._extra_froms
         is_multitable = bool(extra_froms)
 
@@ -6040,7 +6052,11 @@ class SQLCompiler(Compiled):
             )
 
         table_text = self.update_tables_clause(
-            update_stmt, update_stmt.table, render_extra_froms, **kw
+            update_stmt,
+            update_stmt.table,
+            render_extra_froms,
+            from_linter=from_linter,
+            **kw,
         )
         crud_params_struct = crud._get_crud_params(
             self, update_stmt, compile_state, toplevel, **kw
@@ -6081,6 +6097,7 @@ class SQLCompiler(Compiled):
                 update_stmt.table,
                 render_extra_froms,
                 dialect_hints,
+                from_linter=from_linter,
                 **kw,
             )
             if extra_from_text:
@@ -6088,7 +6105,7 @@ class SQLCompiler(Compiled):
 
         if update_stmt._where_criteria:
             t = self._generate_delimited_and_list(
-                update_stmt._where_criteria, **kw
+                update_stmt._where_criteria, from_linter=from_linter, **kw
             )
             if t:
                 text += " WHERE " + t
@@ -6110,6 +6127,10 @@ class SQLCompiler(Compiled):
             nesting_level = len(self.stack) if not toplevel else None
             text = self._render_cte_clause(nesting_level=nesting_level) + text
 
+        if warn_linting:
+            assert from_linter is not None
+            from_linter.warn(stmt_type="UPDATE")
+
         self.stack.pop(-1)
 
         return text
@@ -6130,8 +6151,10 @@ class SQLCompiler(Compiled):
             "criteria within DELETE"
         )
 
-    def delete_table_clause(self, delete_stmt, from_table, extra_froms):
-        return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
+    def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
+        return from_table._compiler_dispatch(
+            self, asfrom=True, iscrud=True, **kw
+        )
 
     def visit_delete(self, delete_stmt, **kw):
         compile_state = delete_stmt._compile_state_factory(
@@ -6147,6 +6170,15 @@ class SQLCompiler(Compiled):
             if not self.compile_state:
                 self.compile_state = compile_state
 
+        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+            from_linter = FromLinter({}, set())
+            warn_linting = self.linting & WARN_LINTING
+            if toplevel:
+                self.from_linter = from_linter
+        else:
+            from_linter = None
+            warn_linting = False
+
         extra_froms = compile_state._extra_froms
 
         correlate_froms = {delete_stmt.table}.union(extra_froms)
@@ -6166,9 +6198,22 @@ class SQLCompiler(Compiled):
             )
 
         text += "FROM "
-        table_text = self.delete_table_clause(
-            delete_stmt, delete_stmt.table, extra_froms
-        )
+
+        try:
+            table_text = self.delete_table_clause(
+                delete_stmt,
+                delete_stmt.table,
+                extra_froms,
+                from_linter=from_linter,
+            )
+        except TypeError:
+            # anticipate 3rd party dialects that don't include **kw
+            # TODO: remove in 2.1
+            table_text = self.delete_table_clause(
+                delete_stmt, delete_stmt.table, extra_froms
+            )
+            if from_linter:
+                _ = self.process(delete_stmt.table, from_linter=from_linter)
 
         crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)
 
@@ -6199,6 +6244,7 @@ class SQLCompiler(Compiled):
                 delete_stmt.table,
                 extra_froms,
                 dialect_hints,
+                from_linter=from_linter,
                 **kw,
             )
             if extra_from_text:
@@ -6206,7 +6252,7 @@ class SQLCompiler(Compiled):
 
         if delete_stmt._where_criteria:
             t = self._generate_delimited_and_list(
-                delete_stmt._where_criteria, **kw
+                delete_stmt._where_criteria, from_linter=from_linter, **kw
             )
             if t:
                 text += " WHERE " + t
@@ -6224,6 +6270,10 @@ class SQLCompiler(Compiled):
             nesting_level = len(self.stack) if not toplevel else None
             text = self._render_cte_clause(nesting_level=nesting_level) + text
 
+        if warn_linting:
+            assert from_linter is not None
+            from_linter.warn(stmt_type="DELETE")
+
         self.stack.pop(-1)
 
         return text
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 563f61c046da643d57e69af3a7b5174d4b823680..16d5ce494116d9b8464cf0047839b50c66812a96 100644 (file)
@@ -1344,7 +1344,7 @@ def _get_update_multitable_params(
 ):
     normalized_params = {
         coercions.expect(roles.DMLColumnRole, c): param
-        for c, param in stmt_parameter_tuples
+        for c, param in stmt_parameter_tuples or ()
     }
 
     include_table = compile_state.include_table_with_column_exprs
diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py
index 49370b1e67e3c0cedef56fc07cceafcc7df74e19..9a471d57126477f50f16909c1d970ab89b908054 100644 (file)
@@ -1,4 +1,5 @@
 from sqlalchemy import column
+from sqlalchemy import delete
 from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import JSON
@@ -7,6 +8,7 @@ from sqlalchemy import sql
 from sqlalchemy import table
 from sqlalchemy import testing
 from sqlalchemy import true
+from sqlalchemy import update
 from sqlalchemy.testing import config
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import expect_warnings
@@ -382,18 +384,54 @@ class TestFindUnmatchingFroms(fixtures.TablesTest):
         froms, start = find_unmatching_froms(query)
         assert not froms
 
+    @testing.variation("dml", ["update", "delete"])
+    @testing.combinations(
+        (False, False), (True, False), (True, True), argnames="twotable,error"
+    )
+    def test_dml(self, dml, twotable, error):
+        if dml.update:
+            stmt = update(self.a)
+        elif dml.delete:
+            stmt = delete(self.a)
+        else:
+            dml.fail()
+
+        stmt = stmt.where(self.a.c.col_a == "a1")
+        if twotable:
+            stmt = stmt.where(self.b.c.col_b == "a1")
+
+            if not error:
+                stmt = stmt.where(self.b.c.col_b == self.a.c.col_a)
+
+        froms, _ = find_unmatching_froms(stmt)
+        if error:
+            assert froms
+        else:
+            assert not froms
+
+
+class TestLinterRoundTrip(fixtures.TablesTest):
+    __backend__ = True
 
-class TestLinter(fixtures.TablesTest):
     @classmethod
     def define_tables(cls, metadata):
-        Table("table_a", metadata, Column("col_a", Integer, primary_key=True))
-        Table("table_b", metadata, Column("col_b", Integer, primary_key=True))
+        Table(
+            "table_a",
+            metadata,
+            Column("col_a", Integer, primary_key=True, autoincrement=False),
+        )
+        Table(
+            "table_b",
+            metadata,
+            Column("col_b", Integer, primary_key=True, autoincrement=False),
+        )
 
     @classmethod
     def setup_bind(cls):
         # from linting is enabled by default
         return config.db
 
+    @testing.only_on("sqlite")
     def test_noop_for_unhandled_objects(self):
         with self.bind.connect() as conn:
             conn.exec_driver_sql("SELECT 1;").fetchone()
@@ -429,6 +467,7 @@ class TestLinter(fixtures.TablesTest):
             with self.bind.connect() as conn:
                 conn.execute(query)
 
+    @testing.requires.ctes
     def test_warn_anon_cte(self):
         a, b = self.tables("table_a", "table_b")
 
@@ -444,6 +483,47 @@ class TestLinter(fixtures.TablesTest):
             with self.bind.connect() as conn:
                 conn.execute(query)
 
+    @testing.variation(
+        "dml",
+        [
+            ("update", testing.requires.update_from),
+            ("delete", testing.requires.delete_using),
+        ],
+    )
+    @testing.combinations(
+        (False, False), (True, False), (True, True), argnames="twotable,error"
+    )
+    def test_warn_dml(self, dml, twotable, error):
+        a, b = self.tables("table_a", "table_b")
+
+        if dml.update:
+            stmt = update(a).values(col_a=5)
+        elif dml.delete:
+            stmt = delete(a)
+        else:
+            dml.fail()
+
+        stmt = stmt.where(a.c.col_a == 1)
+        if twotable:
+            stmt = stmt.where(b.c.col_b == 1)
+
+            if not error:
+                stmt = stmt.where(b.c.col_b == a.c.col_a)
+
+        stmt_type = "UPDATE" if dml.update else "DELETE"
+
+        with self.bind.connect() as conn:
+            if error:
+                with expect_warnings(
+                    rf"{stmt_type} statement has a cartesian product between "
+                    rf'FROM element\(s\) "table_[ab]" and FROM '
+                    rf'element "table_[ab]"'
+                ):
+                    with self.bind.connect() as conn:
+                        conn.execute(stmt)
+            else:
+                conn.execute(stmt)
+
     def test_no_linting(self, metadata, connection):
         eng = engines.testing_engine(
             options={"enable_from_linting": False, "use_reaper": False}
diff --git a/test/sql/test_update.py b/test/sql/test_update.py
index ef8f117bcd9b6ab369d05ef62f7e1d3662bda783..d8de5c277b5cb76a2aa9ff58427f23f13ddef3d7 100644 (file)
@@ -113,6 +113,59 @@ class _UpdateFromTestBase:
 class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
     __dialect__ = "default_enhanced"
 
+    @testing.variation("twotable", [True, False])
+    @testing.variation("values", ["none", "blank"])
+    def test_update_no_params(self, values, twotable):
+        """test issue identified while doing #9721
+
+        UPDATE with empty VALUES but multiple tables would raise a
+        NoneType error; fixed this to emit an empty "SET" the way a single
+        table UPDATE currently does.
+
+        both cases should probably raise CompileError, however this could
+        be backwards incompatible with current use cases (such as other test
+        suites)
+
+        """
+
+        table1 = self.tables.mytable
+        table2 = self.tables.myothertable
+
+        stmt = table1.update().where(table1.c.name == "jill")
+        if twotable:
+            stmt = stmt.where(table2.c.otherid == table1.c.myid)
+
+        if values.blank:
+            stmt = stmt.values()
+
+        if twotable:
+            if values.blank:
+                self.assert_compile(
+                    stmt,
+                    "UPDATE mytable SET  FROM myothertable "
+                    "WHERE mytable.name = :name_1 "
+                    "AND myothertable.otherid = mytable.myid",
+                )
+            elif values.none:
+                self.assert_compile(
+                    stmt,
+                    "UPDATE mytable SET myid=:myid, name=:name, "
+                    "description=:description FROM myothertable "
+                    "WHERE mytable.name = :name_1 "
+                    "AND myothertable.otherid = mytable.myid",
+                )
+        elif values.blank:
+            self.assert_compile(
+                stmt,
+                "UPDATE mytable SET  WHERE mytable.name = :name_1",
+            )
+        elif values.none:
+            self.assert_compile(
+                stmt,
+                "UPDATE mytable SET myid=:myid, name=:name, "
+                "description=:description WHERE mytable.name = :name_1",
+            )
+
     def test_update_literal_binds(self):
         table1 = self.tables.mytable