From: Mike Bayer Date: Tue, 12 Dec 2023 19:57:38 +0000 (-0500) Subject: copy stack related elements to str compiler X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0248efb761bec4bdcea76bc6bbe3c09934f6b527;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git copy stack related elements to str compiler Fixed issue in stringify for SQL elements, where a specific dialect is not passed, where a dialect-specific element such as the PostgreSQL "on conflict do update" construct is encountered and then fails to provide for a stringify dialect with the appropriate state to render the construct, leading to internal errors. Fixed issue where stringifying or compiling a :class:`.CTE` that was against a DML construct such as an :func:`_sql.insert` construct would fail to stringify, due to a mis-detection that the statement overall is an INSERT, leading to internal errors. Fixes: #10753 Change-Id: I783eca3fc7bbc1794fedd325d58181dbcc7e0b75 --- diff --git a/doc/build/changelog/unreleased_20/10753.rst b/doc/build/changelog/unreleased_20/10753.rst new file mode 100644 index 0000000000..5b714ed197 --- /dev/null +++ b/doc/build/changelog/unreleased_20/10753.rst @@ -0,0 +1,17 @@ +.. change:: + :tags: bug, sql + :tickets: 10753 + + Fixed issue in stringify for SQL elements, where a specific dialect is not + passed, where a dialect-specific element such as the PostgreSQL "on + conflict do update" construct is encountered and then fails to provide for + a stringify dialect with the appropriate state to render the construct, + leading to internal errors. + +.. change:: + :tags: bug, sql + + Fixed issue where stringifying or compiling a :class:`.CTE` that was + against a DML construct such as an :func:`_sql.insert` construct would fail + to stringify, due to a mis-detection that the statement overall is an + INSERT, leading to internal errors. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index cb6899c5e9..b4b8bcfd26 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1343,6 +1343,7 @@ class SQLCompiler(Compiled): column_keys: Optional[Sequence[str]] = None, for_executemany: bool = False, linting: Linting = NO_LINTING, + _supporting_against: Optional[SQLCompiler] = None, **kwargs: Any, ): """Construct a new :class:`.SQLCompiler` object. @@ -1445,6 +1446,24 @@ class SQLCompiler(Compiled): self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + if _supporting_against: + self.__dict__.update( + { + k: v + for k, v in _supporting_against.__dict__.items() + if k + not in { + "state", + "dialect", + "preparer", + "positional", + "_numeric_binds", + "compilation_bindtemplate", + "bindtemplate", + } + } + ) + if self.state is CompilerState.STRING_APPLIED: if self.positional: if self._numeric_binds: @@ -5595,13 +5614,19 @@ class SQLCompiler(Compiled): ) batchnum += 1 - def visit_insert(self, insert_stmt, visited_bindparam=None, **kw): + def visit_insert( + self, insert_stmt, visited_bindparam=None, visiting_cte=None, **kw + ): compile_state = insert_stmt._compile_state_factory( insert_stmt, self, **kw ) insert_stmt = compile_state.statement - toplevel = not self.stack + if visiting_cte is not None: + kw["visiting_cte"] = visiting_cte + toplevel = False + else: + toplevel = not self.stack if toplevel: self.isinsert = True @@ -5629,14 +5654,12 @@ class SQLCompiler(Compiled): # params inside them. After multiple attempts to figure this out, # this very simplistic "count after" works and is # likely the least amount of callcounts, though looks clumsy - if self.positional: + if self.positional and visiting_cte is None: # if we are inside a CTE, don't count parameters # here since they wont be for insertmanyvalues. keep # visited_bindparam at None so no counting happens. # see #9173 - has_visiting_cte = "visiting_cte" in kw - if not has_visiting_cte: - visited_bindparam = [] + visited_bindparam = [] crud_params_struct = crud._get_crud_params( self, @@ -5990,13 +6013,18 @@ class SQLCompiler(Compiled): "criteria within UPDATE" ) - def visit_update(self, update_stmt, **kw): + def visit_update(self, update_stmt, visiting_cte=None, **kw): compile_state = update_stmt._compile_state_factory( update_stmt, self, **kw ) update_stmt = compile_state.statement - toplevel = not self.stack + if visiting_cte is not None: + kw["visiting_cte"] = visiting_cte + toplevel = False + else: + toplevel = not self.stack + if toplevel: self.isupdate = True if not self.dml_compile_state: @@ -6147,13 +6175,18 @@ class SQLCompiler(Compiled): self, asfrom=True, iscrud=True, **kw ) - def visit_delete(self, delete_stmt, **kw): + def visit_delete(self, delete_stmt, visiting_cte=None, **kw): compile_state = delete_stmt._compile_state_factory( delete_stmt, self, **kw ) delete_stmt = compile_state.statement - toplevel = not self.stack + if visiting_cte is not None: + kw["visiting_cte"] = visiting_cte + toplevel = False + else: + toplevel = not self.stack + if toplevel: self.isdelete = True if not self.dml_compile_state: @@ -6312,9 +6345,11 @@ class StrSQLCompiler(SQLCompiler): url = util.preloaded.engine_url dialect = url.URL.create(element.stringify_dialect).get_dialect()() - compiler = dialect.statement_compiler(dialect, None) + compiler = dialect.statement_compiler( + dialect, None, _supporting_against=self + ) if not isinstance(compiler, StrSQLCompiler): - return compiler.process(element) + return compiler.process(element, **kw) return super().visit_unsupported_compilation(element, err) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 3bd1bacc6d..d6bc098964 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -5974,6 +5974,53 @@ class StringifySpecialTest(fixtures.TestBase): ): eq_(str(Grouping(Widget())), "(widget)") + def test_dialect_sub_compile_has_stack(self): + """test #10753""" + + class Widget(ColumnElement): + __visit_name__ = "widget" + stringify_dialect = "sqlite" + + def visit_widget(self, element, **kw): + assert self.stack + return "widget" + + with mock.patch( + "sqlalchemy.dialects.sqlite.base.SQLiteCompiler.visit_widget", + visit_widget, + create=True, + ): + eq_(str(select(Widget())), "SELECT widget AS anon_1") + + def test_dialect_sub_compile_has_stack_pg_specific(self): + """test #10753""" + my_table = table( + "my_table", column("id"), column("data"), column("user_email") + ) + + from sqlalchemy.dialects.postgresql import insert + + insert_stmt = insert(my_table).values( + id="some_existing_id", data="inserted value" + ) + + do_update_stmt = insert_stmt.on_conflict_do_update( + index_elements=["id"], set_=dict(data="updated value") + ) + + # note! two different bound parameter formats. It's weird yes, + # but this is what I want. They are stringifying without using the + # correct dialect. We could use the PG compiler at the point of + # the insert() but that still would not accommodate params in other + # parts of the statement. + eq_ignore_whitespace( + str(select(do_update_stmt.cte())), + "WITH anon_1 AS (INSERT INTO my_table (id, data) " + "VALUES (:param_1, :param_2) " + "ON CONFLICT (id) " + "DO UPDATE SET data = %(param_3)s) SELECT FROM anon_1", + ) + def test_dialect_sub_compile_w_binds(self): """test sub-compile into a new compiler where state != CompilerState.COMPILING, but we have to render a bindparam diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index d044212aa6..23ac87a214 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1383,6 +1383,36 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): else: assert False + @testing.variation("operation", ["insert", "update", "delete"]) + def test_stringify_standalone_dml_cte(self, operation): + """test issue discovered as part of #10753""" + + t1 = table("table_1", column("id"), column("val")) + + if operation.insert: + stmt = t1.insert() + expected = ( + "INSERT INTO table_1 (id, val) VALUES (:id, :val) " + "RETURNING table_1.id, table_1.val" + ) + elif operation.update: + stmt = t1.update() + expected = ( + "UPDATE table_1 SET id=:id, val=:val " + "RETURNING table_1.id, table_1.val" + ) + elif operation.delete: + stmt = t1.delete() + expected = "DELETE FROM table_1 RETURNING table_1.id, table_1.val" + else: + operation.fail() + + stmt = stmt.returning(t1.c.id, t1.c.val) + + cte = stmt.cte() + + self.assert_compile(cte, expected) + @testing.combinations( ("default_enhanced",), ("postgresql",),