]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
copy stack related elements to str compiler
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Dec 2023 19:57:38 +0000 (14:57 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Dec 2023 22:07:12 +0000 (17:07 -0500)
Fixed issue in stringify for SQL elements, where a specific dialect is not
passed,  where a dialect-specific element such as the PostgreSQL "on
conflict do update" construct is encountered and then fails to provide for
a stringify dialect with the appropriate state to render the construct,
leading to internal errors.

Fixed issue where stringifying or compiling a :class:`.CTE` that was
against a DML construct such as an :func:`_sql.insert` construct would fail
to stringify, due to a mis-detection that the statement overall is an
INSERT, leading to internal errors.

Fixes: #10753
Change-Id: I783eca3fc7bbc1794fedd325d58181dbcc7e0b75

doc/build/changelog/unreleased_20/10753.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/sql/test_compiler.py
test/sql/test_cte.py

diff --git a/doc/build/changelog/unreleased_20/10753.rst b/doc/build/changelog/unreleased_20/10753.rst
new file mode 100644 (file)
index 0000000..5b714ed
--- /dev/null
@@ -0,0 +1,17 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 10753
+
+    Fixed issue in stringify for SQL elements, where a specific dialect is not
+    passed,  where a dialect-specific element such as the PostgreSQL "on
+    conflict do update" construct is encountered and then fails to provide for
+    a stringify dialect with the appropriate state to render the construct,
+    leading to internal errors.
+
+.. change::
+    :tags: bug, sql
+
+    Fixed issue where stringifying or compiling a :class:`.CTE` that was
+    against a DML construct such as an :func:`_sql.insert` construct would fail
+    to stringify, due to a mis-detection that the statement overall is an
+    INSERT, leading to internal errors.
index cb6899c5e9a8761a321a54a232c39a61f293cc3b..b4b8bcfd26e1102d69e16f04a86d6f2dbc3ec0aa 100644 (file)
@@ -1343,6 +1343,7 @@ class SQLCompiler(Compiled):
         column_keys: Optional[Sequence[str]] = None,
         for_executemany: bool = False,
         linting: Linting = NO_LINTING,
+        _supporting_against: Optional[SQLCompiler] = None,
         **kwargs: Any,
     ):
         """Construct a new :class:`.SQLCompiler` object.
@@ -1445,6 +1446,24 @@ class SQLCompiler(Compiled):
 
         self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
 
+        if _supporting_against:
+            self.__dict__.update(
+                {
+                    k: v
+                    for k, v in _supporting_against.__dict__.items()
+                    if k
+                    not in {
+                        "state",
+                        "dialect",
+                        "preparer",
+                        "positional",
+                        "_numeric_binds",
+                        "compilation_bindtemplate",
+                        "bindtemplate",
+                    }
+                }
+            )
+
         if self.state is CompilerState.STRING_APPLIED:
             if self.positional:
                 if self._numeric_binds:
@@ -5595,13 +5614,19 @@ class SQLCompiler(Compiled):
             )
             batchnum += 1
 
-    def visit_insert(self, insert_stmt, visited_bindparam=None, **kw):
+    def visit_insert(
+        self, insert_stmt, visited_bindparam=None, visiting_cte=None, **kw
+    ):
         compile_state = insert_stmt._compile_state_factory(
             insert_stmt, self, **kw
         )
         insert_stmt = compile_state.statement
 
-        toplevel = not self.stack
+        if visiting_cte is not None:
+            kw["visiting_cte"] = visiting_cte
+            toplevel = False
+        else:
+            toplevel = not self.stack
 
         if toplevel:
             self.isinsert = True
@@ -5629,14 +5654,12 @@ class SQLCompiler(Compiled):
         # params inside them.   After multiple attempts to figure this out,
         # this very simplistic "count after" works and is
         # likely the least amount of callcounts, though looks clumsy
-        if self.positional:
+        if self.positional and visiting_cte is None:
             # if we are inside a CTE, don't count parameters
             # here since they wont be for insertmanyvalues. keep
             # visited_bindparam at None so no counting happens.
             # see #9173
-            has_visiting_cte = "visiting_cte" in kw
-            if not has_visiting_cte:
-                visited_bindparam = []
+            visited_bindparam = []
 
         crud_params_struct = crud._get_crud_params(
             self,
@@ -5990,13 +6013,18 @@ class SQLCompiler(Compiled):
             "criteria within UPDATE"
         )
 
-    def visit_update(self, update_stmt, **kw):
+    def visit_update(self, update_stmt, visiting_cte=None, **kw):
         compile_state = update_stmt._compile_state_factory(
             update_stmt, self, **kw
         )
         update_stmt = compile_state.statement
 
-        toplevel = not self.stack
+        if visiting_cte is not None:
+            kw["visiting_cte"] = visiting_cte
+            toplevel = False
+        else:
+            toplevel = not self.stack
+
         if toplevel:
             self.isupdate = True
             if not self.dml_compile_state:
@@ -6147,13 +6175,18 @@ class SQLCompiler(Compiled):
             self, asfrom=True, iscrud=True, **kw
         )
 
-    def visit_delete(self, delete_stmt, **kw):
+    def visit_delete(self, delete_stmt, visiting_cte=None, **kw):
         compile_state = delete_stmt._compile_state_factory(
             delete_stmt, self, **kw
         )
         delete_stmt = compile_state.statement
 
-        toplevel = not self.stack
+        if visiting_cte is not None:
+            kw["visiting_cte"] = visiting_cte
+            toplevel = False
+        else:
+            toplevel = not self.stack
+
         if toplevel:
             self.isdelete = True
             if not self.dml_compile_state:
@@ -6312,9 +6345,11 @@ class StrSQLCompiler(SQLCompiler):
             url = util.preloaded.engine_url
             dialect = url.URL.create(element.stringify_dialect).get_dialect()()
 
-            compiler = dialect.statement_compiler(dialect, None)
+            compiler = dialect.statement_compiler(
+                dialect, None, _supporting_against=self
+            )
             if not isinstance(compiler, StrSQLCompiler):
-                return compiler.process(element)
+                return compiler.process(element, **kw)
 
         return super().visit_unsupported_compilation(element, err)
 
index 3bd1bacc6d8855982c541c6e4f5216c02f90ae76..d6bc098964c777e10f8375db6250026a8c375467 100644 (file)
@@ -5974,6 +5974,53 @@ class StringifySpecialTest(fixtures.TestBase):
         ):
             eq_(str(Grouping(Widget())), "(widget)")
 
+    def test_dialect_sub_compile_has_stack(self):
+        """test #10753"""
+
+        class Widget(ColumnElement):
+            __visit_name__ = "widget"
+            stringify_dialect = "sqlite"
+
+        def visit_widget(self, element, **kw):
+            assert self.stack
+            return "widget"
+
+        with mock.patch(
+            "sqlalchemy.dialects.sqlite.base.SQLiteCompiler.visit_widget",
+            visit_widget,
+            create=True,
+        ):
+            eq_(str(select(Widget())), "SELECT widget AS anon_1")
+
+    def test_dialect_sub_compile_has_stack_pg_specific(self):
+        """test #10753"""
+        my_table = table(
+            "my_table", column("id"), column("data"), column("user_email")
+        )
+
+        from sqlalchemy.dialects.postgresql import insert
+
+        insert_stmt = insert(my_table).values(
+            id="some_existing_id", data="inserted value"
+        )
+
+        do_update_stmt = insert_stmt.on_conflict_do_update(
+            index_elements=["id"], set_=dict(data="updated value")
+        )
+
+        # note!  two different bound parameter formats.   It's weird yes,
+        # but this is what I want.  They are stringifying without using the
+        # correct dialect.   We could use the PG compiler at the point of
+        # the insert() but that still would not accommodate params in other
+        # parts of the statement.
+        eq_ignore_whitespace(
+            str(select(do_update_stmt.cte())),
+            "WITH anon_1 AS (INSERT INTO my_table (id, data) "
+            "VALUES (:param_1, :param_2) "
+            "ON CONFLICT (id) "
+            "DO UPDATE SET data = %(param_3)s) SELECT FROM anon_1",
+        )
+
     def test_dialect_sub_compile_w_binds(self):
         """test sub-compile into a new compiler where
         state != CompilerState.COMPILING, but we have to render a bindparam
index d044212aa60230069cee8782f6f835065d38cf7c..23ac87a2148d410f3c58b47e0b23a16bbd198f0a 100644 (file)
@@ -1383,6 +1383,36 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
         else:
             assert False
 
+    @testing.variation("operation", ["insert", "update", "delete"])
+    def test_stringify_standalone_dml_cte(self, operation):
+        """test issue discovered as part of #10753"""
+
+        t1 = table("table_1", column("id"), column("val"))
+
+        if operation.insert:
+            stmt = t1.insert()
+            expected = (
+                "INSERT INTO table_1 (id, val) VALUES (:id, :val) "
+                "RETURNING table_1.id, table_1.val"
+            )
+        elif operation.update:
+            stmt = t1.update()
+            expected = (
+                "UPDATE table_1 SET id=:id, val=:val "
+                "RETURNING table_1.id, table_1.val"
+            )
+        elif operation.delete:
+            stmt = t1.delete()
+            expected = "DELETE FROM table_1 RETURNING table_1.id, table_1.val"
+        else:
+            operation.fail()
+
+        stmt = stmt.returning(t1.c.id, t1.c.val)
+
+        cte = stmt.cte()
+
+        self.assert_compile(cte, expected)
+
     @testing.combinations(
         ("default_enhanced",),
         ("postgresql",),