From: Mike Bayer Date: Sun, 29 Jan 2023 00:50:25 +0000 (-0500) Subject: don't count / gather INSERT bind names inside of a CTE X-Git-Tag: rel_2_0_1~13^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d23dcbaea2a8e000c5fa2ba443e1b683b3b79fa6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git don't count / gather INSERT bind names inside of a CTE Fixed regression related to the implementation for the new "insertmanyvalues" feature where an internal ``TypeError`` would occur in arrangements where a :func:`_sql.insert` would be referred to inside of another :func:`_sql.insert` via a CTE; made additional repairs for this use case for positional dialects such as asyncpg when using "insertmanyvalues". At the core here is a change to positional insertmanyvalues where we now get exactly the positions for the "manyvalues" within the larger list, allowing non-"manyvalues" on the left and right sides at the same time, not assuming anything about how RETURNING renders etc., since CTEs are in the mix also. Fixes: #9173 Change-Id: I5ff071fbef0d92a2d6046b9c4e609bb008438afd --- diff --git a/doc/build/changelog/unreleased_20/9173.rst b/doc/build/changelog/unreleased_20/9173.rst new file mode 100644 index 0000000000..0e0f595201 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9173.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, sql, regression + :tickets: 9173 + + Fixed regression related to the implementation for the new + "insertmanyvalues" feature where an internal ``TypeError`` would occur in + arrangements where a :func:`_sql.insert` would be referred to inside + of another :func:`_sql.insert` via a CTE; made additional repairs for this + use case for positional dialects such as asyncpg when using + "insertmanyvalues". 
+ + diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 2c50081fbf..d4ddc2e5da 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1545,12 +1545,12 @@ class SQLCompiler(Compiled): self.positiontup = list(param_pos) if self.escaped_bind_names: - reverse_escape = {v: k for k, v in self.escaped_bind_names.items()} - assert len(self.escaped_bind_names) == len(reverse_escape) + len_before = len(param_pos) param_pos = { self.escaped_bind_names.get(name, name): pos for name, pos in param_pos.items() } + assert len(param_pos) == len_before # Can't use format here since % chars are not escaped. self.string = self._pyformat_pattern.sub( @@ -3374,7 +3374,6 @@ class SQLCompiler(Compiled): skip_bind_expression=False, literal_execute=False, render_postcompile=False, - accumulate_bind_names=None, **kwargs, ): if not skip_bind_expression: @@ -3388,7 +3387,6 @@ class SQLCompiler(Compiled): literal_binds=literal_binds and not bindparam.expanding, literal_execute=literal_execute, render_postcompile=render_postcompile, - accumulate_bind_names=accumulate_bind_names, **kwargs, ) if bindparam.expanding: @@ -3490,9 +3488,6 @@ class SQLCompiler(Compiled): self.binds[bindparam.key] = self.binds[name] = bindparam - if accumulate_bind_names is not None: - accumulate_bind_names.add(name) - # if we are given a cache key that we're going to match against, # relate the bindparam here to one that is most likely present # in the "extracted params" portion of the cache key. 
this is used @@ -3646,11 +3641,19 @@ class SQLCompiler(Compiled): expanding: bool = False, escaped_from: Optional[str] = None, bindparam_type: Optional[TypeEngine[Any]] = None, + accumulate_bind_names: Optional[Set[str]] = None, + visited_bindparam: Optional[List[str]] = None, **kw: Any, ) -> str: - if self._visited_bindparam is not None: - self._visited_bindparam.append(name) + # TODO: accumulate_bind_names is passed by crud.py to gather + # names on a per-value basis, visited_bindparam is passed by + # visit_insert() to collect all parameters in the statement. + # see if this gathering can be simplified somehow + if accumulate_bind_names is not None: + accumulate_bind_names.add(name) + if visited_bindparam is not None: + visited_bindparam.append(name) if not escaped_from: @@ -5086,6 +5089,8 @@ class SQLCompiler(Compiled): assert insert_crud_params is not None escaped_bind_names: Mapping[str, str] + expand_pos_lower_index = expand_pos_upper_index = 0 + if not self.positional: if self.escaped_bind_names: escaped_bind_names = self.escaped_bind_names @@ -5124,6 +5129,31 @@ class SQLCompiler(Compiled): keys_to_replace = set() base_parameters = {} executemany_values_w_comma = f"({imv.single_values_expr}), " + + all_names_we_will_expand: Set[str] = set() + for elem in imv.insert_crud_params: + all_names_we_will_expand.update(elem[3]) + + # get the start and end position in a particular list + # of parameters where we will be doing the "expanding". 
+ # statements can have params on either side or both sides, + # given RETURNING and CTEs + if all_names_we_will_expand: + positiontup = self.positiontup + assert positiontup is not None + + all_expand_positions = { + idx + for idx, name in enumerate(positiontup) + if name in all_names_we_will_expand + } + expand_pos_lower_index = min(all_expand_positions) + expand_pos_upper_index = max(all_expand_positions) + 1 + assert ( + len(all_expand_positions) + == expand_pos_upper_index - expand_pos_lower_index + ) + if self._numeric_binds: escaped = re.escape(self._numeric_binds_identifier_char) executemany_values_w_comma = re.sub( @@ -5149,52 +5179,61 @@ class SQLCompiler(Compiled): replaced_parameters: Any if self.positional: - # the assumption here is that any parameters that are not - # in the VALUES clause are expected to be parameterized - # expressions in the RETURNING (or maybe ON CONFLICT) clause. - # So based on - # which sequence comes first in the compiler's INSERT - # statement tells us where to expand the parameters. - - # otherwise we probably shouldn't be doing insertmanyvalues - # on the statement. 
- num_ins_params = imv.num_positional_params_counted batch_iterator: Iterable[Tuple[Any, ...]] if num_ins_params == len(batch[0]): - extra_params = () + extra_params_left = extra_params_right = () batch_iterator = batch - elif self.returning_precedes_values or self._numeric_binds: - extra_params = batch[0][:-num_ins_params] - batch_iterator = (b[-num_ins_params:] for b in batch) else: - extra_params = batch[0][num_ins_params:] - batch_iterator = (b[:num_ins_params] for b in batch) + extra_params_left = batch[0][:expand_pos_lower_index] + extra_params_right = batch[0][expand_pos_upper_index:] + batch_iterator = ( + b[expand_pos_lower_index:expand_pos_upper_index] + for b in batch + ) + + expanded_values_string = ( + executemany_values_w_comma * len(batch) + )[:-2] - values_string = (executemany_values_w_comma * len(batch))[:-2] if self._numeric_binds and num_ins_params > 0: + # numeric will always number the parameters inside of + # VALUES (and thus order self.positiontup) to be higher + # than non-VALUES parameters, no matter where in the + # statement those non-VALUES parameters appear (this is + # ensured in _process_numeric by numbering first all + # params that are not in _values_bindparam) + # therefore all extra params are always + # on the left side and numbered lower than the VALUES + # parameters + assert not extra_params_right + + start = expand_pos_lower_index + 1 + end = num_ins_params * (len(batch)) + start + # need to format here, since statement may contain # unescaped %, while values_string contains just (%s, %s) - start = len(extra_params) + 1 - end = num_ins_params * len(batch) + start positions = tuple( f"{self._numeric_binds_identifier_char}{i}" for i in range(start, end) ) - values_string = values_string % positions + expanded_values_string = expanded_values_string % positions replaced_statement = statement.replace( - "__EXECMANY_TOKEN__", values_string + "__EXECMANY_TOKEN__", expanded_values_string ) replaced_parameters = tuple( 
 itertools.chain.from_iterable(batch_iterator) ) - if self.returning_precedes_values or self._numeric_binds: - replaced_parameters = extra_params + replaced_parameters - else: - replaced_parameters = replaced_parameters + extra_params + + replaced_parameters = ( + extra_params_left + + replaced_parameters + + extra_params_right + ) + else: replaced_values_clauses = [] replaced_parameters = base_parameters.copy() @@ -5224,7 +5263,7 @@ class SQLCompiler(Compiled): ) batchnum += 1 - def visit_insert(self, insert_stmt, **kw): + def visit_insert(self, insert_stmt, visited_bindparam=None, **kw): compile_state = insert_stmt._compile_state_factory( insert_stmt, self, **kw @@ -5250,6 +5289,9 @@ class SQLCompiler(Compiled): counted_bindparam = 0 + # reset any incoming "visited_bindparam" collection + visited_bindparam = None + # for positional, insertmanyvalues needs to know how many # bound parameters are in the VALUES sequence; there's no simple # rule because default expressions etc. can have zero or more @@ -5257,21 +5299,30 @@ class SQLCompiler(Compiled): # this very simplistic "count after" works and is # likely the least amount of callcounts, though looks clumsy if self.positional: - self._visited_bindparam = [] + # if we are inside a CTE, don't count parameters + # here since they won't be for insertmanyvalues. keep + # visited_bindparam at None so no counting happens. 
+ # see #9173 + has_visiting_cte = "visiting_cte" in kw + if not has_visiting_cte: + visited_bindparam = [] crud_params_struct = crud._get_crud_params( - self, insert_stmt, compile_state, toplevel, **kw + self, + insert_stmt, + compile_state, + toplevel, + visited_bindparam=visited_bindparam, + **kw, ) - if self.positional: - assert self._visited_bindparam is not None - counted_bindparam = len(self._visited_bindparam) + if self.positional and visited_bindparam is not None: + counted_bindparam = len(visited_bindparam) if self._numeric_binds: if self._values_bindparam is not None: - self._values_bindparam += self._visited_bindparam + self._values_bindparam += visited_bindparam else: - self._values_bindparam = self._visited_bindparam - self._visited_bindparam = None + self._values_bindparam = visited_bindparam crud_params_single = crud_params_struct.single_params diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 5017afa78e..04b62d1ffd 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -150,6 +150,17 @@ def _get_crud_params( compiler.update_prefetch = [] compiler.implicit_returning = [] + visiting_cte = kw.get("visiting_cte", None) + if visiting_cte is not None: + # for insert -> CTE -> insert, don't populate an incoming + # _crud_accumulate_bind_names collection; the INSERT we process here + # will not be inline within the VALUES of the enclosing INSERT as the + # CTE is placed on the outside. 
See issue #9173 + kw.pop("accumulate_bind_names", None) + assert ( + "accumulate_bind_names" not in kw + ), "Don't know how to handle insert within insert without a CTE" + # getters - these are normally just column.key, # but in the case of mysql multi-table update, the rules for # .key must conditionally take tablename into account diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 502104daea..4ba4eddfeb 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1317,6 +1317,72 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): dialect=dialect, ) + @testing.combinations( + ("default_enhanced",), + ("postgresql",), + ("postgresql+asyncpg",), + ) + def test_insert_w_cte_in_scalar_subquery(self, dialect): + """test #9173""" + + customer = table( + "customer", + column("id"), + column("name"), + ) + order = table( + "order", + column("id"), + column("price"), + column("customer_id"), + ) + + inst = ( + customer.insert() + .values(name="John") + .returning(customer.c.id) + .cte("inst") + ) + + stmt = ( + order.insert() + .values( + price=1, + customer_id=select(inst.c.id).scalar_subquery(), + ) + .add_cte(inst) + ) + + if dialect == "default_enhanced": + self.assert_compile( + stmt, + "WITH inst AS (INSERT INTO customer (name) VALUES (:param_1) " + 'RETURNING customer.id) INSERT INTO "order" ' + "(price, customer_id) VALUES " + "(:price, (SELECT inst.id FROM inst))", + dialect=dialect, + ) + elif dialect == "postgresql": + self.assert_compile( + stmt, + "WITH inst AS (INSERT INTO customer (name) " + "VALUES (%(param_1)s) " + 'RETURNING customer.id) INSERT INTO "order" ' + "(price, customer_id) " + "VALUES (%(price)s, (SELECT inst.id FROM inst))", + dialect=dialect, + ) + elif dialect == "postgresql+asyncpg": + self.assert_compile( + stmt, + "WITH inst AS (INSERT INTO customer (name) VALUES ($2) " + 'RETURNING customer.id) INSERT INTO "order" ' + "(price, customer_id) VALUES ($1, (SELECT inst.id FROM inst))", + dialect=dialect, + ) + else: + 
assert False + @testing.combinations( ("default_enhanced",), ("postgresql",), diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index d9dac75b33..3b5a1856cd 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -23,6 +23,7 @@ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import mock +from sqlalchemy.testing import provision from sqlalchemy.testing.provision import normalize_sequence from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -825,6 +826,119 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): eq_(result.inserted_primary_key_rows, [(1,), (2,), (3,)]) + @testing.requires.ctes_on_dml + @testing.variation("add_expr_returning", [True, False]) + def test_insert_w_bindparam_in_nested_insert( + self, connection, add_expr_returning + ): + """test related to #9173""" + + data, extra_table = self.tables("data", "extra_table") + + inst = ( + extra_table.insert() + .values(x_value="x", y_value="y") + .returning(extra_table.c.id) + .cte("inst") + ) + + stmt = ( + data.insert() + .values(x="the x", z=select(inst.c.id).scalar_subquery()) + .add_cte(inst) + ) + + if add_expr_returning: + stmt = stmt.returning(data.c.id, data.c.y + " returned y") + else: + stmt = stmt.returning(data.c.id) + + result = connection.execute( + stmt, + [ + {"y": "y1"}, + {"y": "y2"}, + {"y": "y3"}, + ], + ) + + result_rows = result.all() + + ids = [row[0] for row in result_rows] + + extra_row = connection.execute( + select(extra_table).order_by(extra_table.c.id) + ).one() + extra_row_id = extra_row[0] + eq_(extra_row, (extra_row_id, "x", "y")) + eq_( + connection.execute(select(data).order_by(data.c.id)).all(), + [ + (ids[0], "the x", "y1", extra_row_id), + (ids[1], "the x", "y2", extra_row_id), + (ids[2], "the x", "y3", extra_row_id), + ], + ) + + 
 @testing.requires.provisioned_upsert + def test_upsert_w_returning(self, connection): + """test cases that will exercise SQL similar to that of + test/orm/dml/test_bulk_statements.py + + """ + + data = self.tables.data + + initial_data = [ + {"x": "x1", "y": "y1", "z": 4}, + {"x": "x2", "y": "y2", "z": 8}, + ] + ids = connection.scalars( + data.insert().returning(data.c.id), initial_data + ).all() + + upsert_data = [ + { + "id": ids[0], + "x": "x1", + "y": "y1", + }, + { + "id": 32, + "x": "x19", + "y": "y7", + }, + { + "id": ids[1], + "x": "x5", + "y": "y6", + }, + { + "id": 28, + "x": "x9", + "y": "y15", + }, + ] + + stmt = provision.upsert( + config, + data, + (data,), + lambda inserted: {"x": inserted.x + " upserted"}, + ) + + result = connection.execute(stmt, upsert_data) + + eq_( + result.all(), + [ + (ids[0], "x1 upserted", "y1", 4), + (32, "x19", "y7", 5), + (ids[1], "x5 upserted", "y2", 8), + (28, "x9", "y15", 5), + ], + ) + @testing.combinations(True, False, argnames="use_returning") @testing.combinations(1, 2, argnames="num_embedded_params") @testing.combinations(True, False, argnames="use_whereclause") @@ -835,7 +949,11 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): def test_insert_w_bindparam_in_subq( self, connection, use_returning, num_embedded_params, use_whereclause ): - """test #8639""" + """test #8639 + + see also test_insert_w_bindparam_in_nested_insert + + """ t = self.tables.data extra = self.tables.extra_table