column_keys: Optional[Sequence[str]] = None,
for_executemany: bool = False,
linting: Linting = NO_LINTING,
+ _supporting_against: Optional[SQLCompiler] = None,
**kwargs: Any,
):
"""Construct a new :class:`.SQLCompiler` object.
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
+ if _supporting_against:
+ self.__dict__.update(
+ {
+ k: v
+ for k, v in _supporting_against.__dict__.items()
+ if k
+ not in {
+ "state",
+ "dialect",
+ "preparer",
+ "positional",
+ "_numeric_binds",
+ "compilation_bindtemplate",
+ "bindtemplate",
+ }
+ }
+ )
+
if self.state is CompilerState.STRING_APPLIED:
if self.positional:
if self._numeric_binds:
)
batchnum += 1
- def visit_insert(self, insert_stmt, visited_bindparam=None, **kw):
+ def visit_insert(
+ self, insert_stmt, visited_bindparam=None, visiting_cte=None, **kw
+ ):
compile_state = insert_stmt._compile_state_factory(
insert_stmt, self, **kw
)
insert_stmt = compile_state.statement
- toplevel = not self.stack
+ if visiting_cte is not None:
+ kw["visiting_cte"] = visiting_cte
+ toplevel = False
+ else:
+ toplevel = not self.stack
if toplevel:
self.isinsert = True
# params inside them. After multiple attempts to figure this out,
# this very simplistic "count after" works and is
# likely the least amount of callcounts, though looks clumsy
- if self.positional:
+ if self.positional and visiting_cte is None:
# if we are inside a CTE, don't count parameters
# here since they wont be for insertmanyvalues. keep
# visited_bindparam at None so no counting happens.
# see #9173
- has_visiting_cte = "visiting_cte" in kw
- if not has_visiting_cte:
- visited_bindparam = []
+ visited_bindparam = []
crud_params_struct = crud._get_crud_params(
self,
"criteria within UPDATE"
)
- def visit_update(self, update_stmt, **kw):
+ def visit_update(self, update_stmt, visiting_cte=None, **kw):
compile_state = update_stmt._compile_state_factory(
update_stmt, self, **kw
)
update_stmt = compile_state.statement
- toplevel = not self.stack
+ if visiting_cte is not None:
+ kw["visiting_cte"] = visiting_cte
+ toplevel = False
+ else:
+ toplevel = not self.stack
+
if toplevel:
self.isupdate = True
if not self.dml_compile_state:
self, asfrom=True, iscrud=True, **kw
)
- def visit_delete(self, delete_stmt, **kw):
+ def visit_delete(self, delete_stmt, visiting_cte=None, **kw):
compile_state = delete_stmt._compile_state_factory(
delete_stmt, self, **kw
)
delete_stmt = compile_state.statement
- toplevel = not self.stack
+ if visiting_cte is not None:
+ kw["visiting_cte"] = visiting_cte
+ toplevel = False
+ else:
+ toplevel = not self.stack
+
if toplevel:
self.isdelete = True
if not self.dml_compile_state:
url = util.preloaded.engine_url
dialect = url.URL.create(element.stringify_dialect).get_dialect()()
- compiler = dialect.statement_compiler(dialect, None)
+ compiler = dialect.statement_compiler(
+ dialect, None, _supporting_against=self
+ )
if not isinstance(compiler, StrSQLCompiler):
- return compiler.process(element)
+ return compiler.process(element, **kw)
return super().visit_unsupported_compilation(element, err)
):
eq_(str(Grouping(Widget())), "(widget)")
+ def test_dialect_sub_compile_has_stack(self):
+ """test #10753"""
+
+ class Widget(ColumnElement):
+ __visit_name__ = "widget"
+ stringify_dialect = "sqlite"
+
+ def visit_widget(self, element, **kw):
+ assert self.stack
+ return "widget"
+
+ with mock.patch(
+ "sqlalchemy.dialects.sqlite.base.SQLiteCompiler.visit_widget",
+ visit_widget,
+ create=True,
+ ):
+ eq_(str(select(Widget())), "SELECT widget AS anon_1")
+
+ def test_dialect_sub_compile_has_stack_pg_specific(self):
+ """test #10753"""
+ my_table = table(
+ "my_table", column("id"), column("data"), column("user_email")
+ )
+
+ from sqlalchemy.dialects.postgresql import insert
+
+ insert_stmt = insert(my_table).values(
+ id="some_existing_id", data="inserted value"
+ )
+
+ do_update_stmt = insert_stmt.on_conflict_do_update(
+ index_elements=["id"], set_=dict(data="updated value")
+ )
+
+ # note! two different bound parameter formats. It's weird yes,
+ # but this is what I want. They are stringifying without using the
+ # correct dialect. We could use the PG compiler at the point of
+ # the insert() but that still would not accommodate params in other
+ # parts of the statement.
+ eq_ignore_whitespace(
+ str(select(do_update_stmt.cte())),
+ "WITH anon_1 AS (INSERT INTO my_table (id, data) "
+ "VALUES (:param_1, :param_2) "
+ "ON CONFLICT (id) "
+ "DO UPDATE SET data = %(param_3)s) SELECT FROM anon_1",
+ )
+
def test_dialect_sub_compile_w_binds(self):
"""test sub-compile into a new compiler where
state != CompilerState.COMPILING, but we have to render a bindparam
else:
assert False
+ @testing.variation("operation", ["insert", "update", "delete"])
+ def test_stringify_standalone_dml_cte(self, operation):
+ """test issue discovered as part of #10753"""
+
+ t1 = table("table_1", column("id"), column("val"))
+
+ if operation.insert:
+ stmt = t1.insert()
+ expected = (
+ "INSERT INTO table_1 (id, val) VALUES (:id, :val) "
+ "RETURNING table_1.id, table_1.val"
+ )
+ elif operation.update:
+ stmt = t1.update()
+ expected = (
+ "UPDATE table_1 SET id=:id, val=:val "
+ "RETURNING table_1.id, table_1.val"
+ )
+ elif operation.delete:
+ stmt = t1.delete()
+ expected = "DELETE FROM table_1 RETURNING table_1.id, table_1.val"
+ else:
+ operation.fail()
+
+ stmt = stmt.returning(t1.c.id, t1.c.val)
+
+ cte = stmt.cte()
+
+ self.assert_compile(cte, expected)
+
@testing.combinations(
("default_enhanced",),
("postgresql",),