]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
don't count / gather INSERT bind names inside of a CTE
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 29 Jan 2023 00:50:25 +0000 (19:50 -0500)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 30 Jan 2023 21:28:53 +0000 (22:28 +0100)
Fixed regression related to the implementation for the new
"insertmanyvalues" feature where an internal ``TypeError`` would occur in
arrangements where a :func:`_sql.insert` would be referred towards inside
of another :func:`_sql.insert` via a CTE; made additional repairs for this
use case for positional dialects such as asyncpg when using
"insertmanyvalues".

at the core here is a change to positional insertmanyvalues
where we now get exactly the positions for the "manyvalues" within
the larger list, allowing non-"manyvalues" on the left and right
sides at the same time, not assuming anything about how RETURNING
renders etc., since CTEs are in the mix also.

Fixes: #9173
Change-Id: I5ff071fbef0d92a2d6046b9c4e609bb008438afd

doc/build/changelog/unreleased_20/9173.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
test/sql/test_cte.py
test/sql/test_insert_exec.py

diff --git a/doc/build/changelog/unreleased_20/9173.rst b/doc/build/changelog/unreleased_20/9173.rst
new file mode 100644 (file)
index 0000000..0e0f595
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, sql, regression
+    :tickets: 9173
+
+    Fixed regression related to the implementation for the new
+    "insertmanyvalues" feature where an internal ``TypeError`` would occur in
+    arrangements where a :func:`_sql.insert` would be referred towards inside
+    of another :func:`_sql.insert` via a CTE; made additional repairs for this
+    use case for positional dialects such as asyncpg when using
+    "insertmanyvalues".
+
+
index 2c50081fbf8d199c484d1da05b484e62426e0674..d4ddc2e5da3f473fba3c9a02d2cff55b004282a9 100644 (file)
@@ -1545,12 +1545,12 @@ class SQLCompiler(Compiled):
 
         self.positiontup = list(param_pos)
         if self.escaped_bind_names:
-            reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
-            assert len(self.escaped_bind_names) == len(reverse_escape)
+            len_before = len(param_pos)
             param_pos = {
                 self.escaped_bind_names.get(name, name): pos
                 for name, pos in param_pos.items()
             }
+            assert len(param_pos) == len_before
 
         # Can't use format here since % chars are not escaped.
         self.string = self._pyformat_pattern.sub(
@@ -3374,7 +3374,6 @@ class SQLCompiler(Compiled):
         skip_bind_expression=False,
         literal_execute=False,
         render_postcompile=False,
-        accumulate_bind_names=None,
         **kwargs,
     ):
         if not skip_bind_expression:
@@ -3388,7 +3387,6 @@ class SQLCompiler(Compiled):
                     literal_binds=literal_binds and not bindparam.expanding,
                     literal_execute=literal_execute,
                     render_postcompile=render_postcompile,
-                    accumulate_bind_names=accumulate_bind_names,
                     **kwargs,
                 )
                 if bindparam.expanding:
@@ -3490,9 +3488,6 @@ class SQLCompiler(Compiled):
 
         self.binds[bindparam.key] = self.binds[name] = bindparam
 
-        if accumulate_bind_names is not None:
-            accumulate_bind_names.add(name)
-
         # if we are given a cache key that we're going to match against,
         # relate the bindparam here to one that is most likely present
         # in the "extracted params" portion of the cache key.  this is used
@@ -3646,11 +3641,19 @@ class SQLCompiler(Compiled):
         expanding: bool = False,
         escaped_from: Optional[str] = None,
         bindparam_type: Optional[TypeEngine[Any]] = None,
+        accumulate_bind_names: Optional[Set[str]] = None,
+        visited_bindparam: Optional[List[str]] = None,
         **kw: Any,
     ) -> str:
 
-        if self._visited_bindparam is not None:
-            self._visited_bindparam.append(name)
+        # TODO: accumulate_bind_names is passed by crud.py to gather
+        # names on a per-value basis, visited_bindparam is passed by
+        # visit_insert() to collect all parameters in the statement.
+        # see if this gathering can be simplified somehow
+        if accumulate_bind_names is not None:
+            accumulate_bind_names.add(name)
+        if visited_bindparam is not None:
+            visited_bindparam.append(name)
 
         if not escaped_from:
 
@@ -5086,6 +5089,8 @@ class SQLCompiler(Compiled):
         assert insert_crud_params is not None
 
         escaped_bind_names: Mapping[str, str]
+        expand_pos_lower_index = expand_pos_upper_index = 0
+
         if not self.positional:
             if self.escaped_bind_names:
                 escaped_bind_names = self.escaped_bind_names
@@ -5124,6 +5129,31 @@ class SQLCompiler(Compiled):
             keys_to_replace = set()
             base_parameters = {}
             executemany_values_w_comma = f"({imv.single_values_expr}), "
+
+            all_names_we_will_expand: Set[str] = set()
+            for elem in imv.insert_crud_params:
+                all_names_we_will_expand.update(elem[3])
+
+            # get the start and end position in a particular list
+            # of parameters where we will be doing the "expanding".
+            # statements can have params on either side or both sides,
+            # given RETURNING and CTEs
+            if all_names_we_will_expand:
+                positiontup = self.positiontup
+                assert positiontup is not None
+
+                all_expand_positions = {
+                    idx
+                    for idx, name in enumerate(positiontup)
+                    if name in all_names_we_will_expand
+                }
+                expand_pos_lower_index = min(all_expand_positions)
+                expand_pos_upper_index = max(all_expand_positions) + 1
+                assert (
+                    len(all_expand_positions)
+                    == expand_pos_upper_index - expand_pos_lower_index
+                )
+
             if self._numeric_binds:
                 escaped = re.escape(self._numeric_binds_identifier_char)
                 executemany_values_w_comma = re.sub(
@@ -5149,52 +5179,61 @@ class SQLCompiler(Compiled):
 
             replaced_parameters: Any
             if self.positional:
-                # the assumption here is that any parameters that are not
-                # in the VALUES clause are expected to be parameterized
-                # expressions in the RETURNING (or maybe ON CONFLICT) clause.
-                # So based on
-                # which sequence comes first in the compiler's INSERT
-                # statement tells us where to expand the parameters.
-
-                # otherwise we probably shouldn't be doing insertmanyvalues
-                # on the statement.
-
                 num_ins_params = imv.num_positional_params_counted
 
                 batch_iterator: Iterable[Tuple[Any, ...]]
                 if num_ins_params == len(batch[0]):
-                    extra_params = ()
+                    extra_params_left = extra_params_right = ()
                     batch_iterator = batch
-                elif self.returning_precedes_values or self._numeric_binds:
-                    extra_params = batch[0][:-num_ins_params]
-                    batch_iterator = (b[-num_ins_params:] for b in batch)
                 else:
-                    extra_params = batch[0][num_ins_params:]
-                    batch_iterator = (b[:num_ins_params] for b in batch)
+                    extra_params_left = batch[0][:expand_pos_lower_index]
+                    extra_params_right = batch[0][expand_pos_upper_index:]
+                    batch_iterator = (
+                        b[expand_pos_lower_index:expand_pos_upper_index]
+                        for b in batch
+                    )
+
+                expanded_values_string = (
+                    executemany_values_w_comma * len(batch)
+                )[:-2]
 
-                values_string = (executemany_values_w_comma * len(batch))[:-2]
                 if self._numeric_binds and num_ins_params > 0:
+                    # numeric will always number the parameters inside of
+                    # VALUES (and thus order self.positiontup) to be higher
+                    # than non-VALUES parameters, no matter where in the
+                    # statement those non-VALUES parameters appear (this is
+                    # ensured in _process_numeric by numbering first all
+                    # params that are not in _values_bindparam)
+                    # therefore all extra params are always
+                    # on the left side and numbered lower than the VALUES
+                    # parameters
+                    assert not extra_params_right
+
+                    start = expand_pos_lower_index + 1
+                    end = num_ins_params * (len(batch)) + start
+
                     # need to format here, since statement may contain
                     # unescaped %, while values_string contains just (%s, %s)
-                    start = len(extra_params) + 1
-                    end = num_ins_params * len(batch) + start
                     positions = tuple(
                         f"{self._numeric_binds_identifier_char}{i}"
                         for i in range(start, end)
                     )
-                    values_string = values_string % positions
+                    expanded_values_string = expanded_values_string % positions
 
                 replaced_statement = statement.replace(
-                    "__EXECMANY_TOKEN__", values_string
+                    "__EXECMANY_TOKEN__", expanded_values_string
                 )
 
                 replaced_parameters = tuple(
                     itertools.chain.from_iterable(batch_iterator)
                 )
-                if self.returning_precedes_values or self._numeric_binds:
-                    replaced_parameters = extra_params + replaced_parameters
-                else:
-                    replaced_parameters = replaced_parameters + extra_params
+
+                replaced_parameters = (
+                    extra_params_left
+                    + replaced_parameters
+                    + extra_params_right
+                )
+
             else:
                 replaced_values_clauses = []
                 replaced_parameters = base_parameters.copy()
@@ -5224,7 +5263,7 @@ class SQLCompiler(Compiled):
             )
             batchnum += 1
 
-    def visit_insert(self, insert_stmt, **kw):
+    def visit_insert(self, insert_stmt, visited_bindparam=None, **kw):
 
         compile_state = insert_stmt._compile_state_factory(
             insert_stmt, self, **kw
@@ -5250,6 +5289,9 @@ class SQLCompiler(Compiled):
 
         counted_bindparam = 0
 
+        # reset any incoming "visited_bindparam" collection
+        visited_bindparam = None
+
         # for positional, insertmanyvalues needs to know how many
         # bound parameters are in the VALUES sequence; there's no simple
         # rule because default expressions etc. can have zero or more
@@ -5257,21 +5299,30 @@ class SQLCompiler(Compiled):
         # this very simplistic "count after" works and is
         # likely the least amount of callcounts, though looks clumsy
         if self.positional:
-            self._visited_bindparam = []
+            # if we are inside a CTE, don't count parameters
+            # here since they wont be for insertmanyvalues. keep
+            # visited_bindparam at None so no counting happens.
+            # see #9173
+            has_visiting_cte = "visiting_cte" in kw
+            if not has_visiting_cte:
+                visited_bindparam = []
 
         crud_params_struct = crud._get_crud_params(
-            self, insert_stmt, compile_state, toplevel, **kw
+            self,
+            insert_stmt,
+            compile_state,
+            toplevel,
+            visited_bindparam=visited_bindparam,
+            **kw,
         )
 
-        if self.positional:
-            assert self._visited_bindparam is not None
-            counted_bindparam = len(self._visited_bindparam)
+        if self.positional and visited_bindparam is not None:
+            counted_bindparam = len(visited_bindparam)
             if self._numeric_binds:
                 if self._values_bindparam is not None:
-                    self._values_bindparam += self._visited_bindparam
+                    self._values_bindparam += visited_bindparam
                 else:
-                    self._values_bindparam = self._visited_bindparam
-            self._visited_bindparam = None
+                    self._values_bindparam = visited_bindparam
 
         crud_params_single = crud_params_struct.single_params
 
index 5017afa78ef913e2a91ce09c036850c54e78b107..04b62d1ffd9daada4147dcc5d501af039e292aed 100644 (file)
@@ -150,6 +150,17 @@ def _get_crud_params(
     compiler.update_prefetch = []
     compiler.implicit_returning = []
 
+    visiting_cte = kw.get("visiting_cte", None)
+    if visiting_cte is not None:
+        # for insert -> CTE -> insert, don't populate an incoming
+        # _crud_accumulate_bind_names collection; the INSERT we process here
+        # will not be inline within the VALUES of the enclosing INSERT as the
+        # CTE is placed on the outside.  See issue #9173
+        kw.pop("accumulate_bind_names", None)
+    assert (
+        "accumulate_bind_names" not in kw
+    ), "Don't know how to handle insert within insert without a CTE"
+
     # getters - these are normally just column.key,
     # but in the case of mysql multi-table update, the rules for
     # .key must conditionally take tablename into account
index 502104daeadc6c2a6a6f3e9053889e28fb5b6f5a..4ba4eddfeb0b9e277f9266f04aab63a5c5923aaa 100644 (file)
@@ -1317,6 +1317,72 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=dialect,
         )
 
+    @testing.combinations(
+        ("default_enhanced",),
+        ("postgresql",),
+        ("postgresql+asyncpg",),
+    )
+    def test_insert_w_cte_in_scalar_subquery(self, dialect):
+        """test #9173"""
+
+        customer = table(
+            "customer",
+            column("id"),
+            column("name"),
+        )
+        order = table(
+            "order",
+            column("id"),
+            column("price"),
+            column("customer_id"),
+        )
+
+        inst = (
+            customer.insert()
+            .values(name="John")
+            .returning(customer.c.id)
+            .cte("inst")
+        )
+
+        stmt = (
+            order.insert()
+            .values(
+                price=1,
+                customer_id=select(inst.c.id).scalar_subquery(),
+            )
+            .add_cte(inst)
+        )
+
+        if dialect == "default_enhanced":
+            self.assert_compile(
+                stmt,
+                "WITH inst AS (INSERT INTO customer (name) VALUES (:param_1) "
+                'RETURNING customer.id) INSERT INTO "order" '
+                "(price, customer_id) VALUES "
+                "(:price, (SELECT inst.id FROM inst))",
+                dialect=dialect,
+            )
+        elif dialect == "postgresql":
+            self.assert_compile(
+                stmt,
+                "WITH inst AS (INSERT INTO customer (name) "
+                "VALUES (%(param_1)s) "
+                'RETURNING customer.id) INSERT INTO "order" '
+                "(price, customer_id) "
+                "VALUES (%(price)s, (SELECT inst.id FROM inst))",
+                dialect=dialect,
+            )
+        elif dialect == "postgresql+asyncpg":
+            self.assert_compile(
+                stmt,
+                "WITH inst AS (INSERT INTO customer (name) VALUES ($2) "
+                'RETURNING customer.id) INSERT INTO "order" '
+                "(price, customer_id) VALUES ($1, (SELECT inst.id FROM inst))",
+                dialect=dialect,
+            )
+        else:
+            assert False
+
     @testing.combinations(
         ("default_enhanced",),
         ("postgresql",),
index d9dac75b332816ec0161df1b5dcaec7856be5e36..3b5a1856cdec09e682f988e96829738a28aac3cc 100644 (file)
@@ -23,6 +23,7 @@ from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import mock
+from sqlalchemy.testing import provision
 from sqlalchemy.testing.provision import normalize_sequence
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -825,6 +826,119 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
 
         eq_(result.inserted_primary_key_rows, [(1,), (2,), (3,)])
 
+    @testing.requires.ctes_on_dml
+    @testing.variation("add_expr_returning", [True, False])
+    def test_insert_w_bindparam_in_nested_insert(
+        self, connection, add_expr_returning
+    ):
+        """test related to #9173"""
+
+        data, extra_table = self.tables("data", "extra_table")
+
+        inst = (
+            extra_table.insert()
+            .values(x_value="x", y_value="y")
+            .returning(extra_table.c.id)
+            .cte("inst")
+        )
+
+        stmt = (
+            data.insert()
+            .values(x="the x", z=select(inst.c.id).scalar_subquery())
+            .add_cte(inst)
+        )
+
+        if add_expr_returning:
+            stmt = stmt.returning(data.c.id, data.c.y + " returned y")
+        else:
+            stmt = stmt.returning(data.c.id)
+
+        result = connection.execute(
+            stmt,
+            [
+                {"y": "y1"},
+                {"y": "y2"},
+                {"y": "y3"},
+            ],
+        )
+
+        result_rows = result.all()
+
+        ids = [row[0] for row in result_rows]
+
+        extra_row = connection.execute(
+            select(extra_table).order_by(extra_table.c.id)
+        ).one()
+        extra_row_id = extra_row[0]
+        eq_(extra_row, (extra_row_id, "x", "y"))
+        eq_(
+            connection.execute(select(data).order_by(data.c.id)).all(),
+            [
+                (ids[0], "the x", "y1", extra_row_id),
+                (ids[1], "the x", "y2", extra_row_id),
+                (ids[2], "the x", "y3", extra_row_id),
+            ],
+        )
+
+    @testing.requires.provisioned_upsert
+    def test_upsert_w_returning(self, connection):
+        """test cases that will execise SQL similar to that of
+        test/orm/dml/test_bulk_statements.py
+
+        """
+
+        data = self.tables.data
+
+        initial_data = [
+            {"x": "x1", "y": "y1", "z": 4},
+            {"x": "x2", "y": "y2", "z": 8},
+        ]
+        ids = connection.scalars(
+            data.insert().returning(data.c.id), initial_data
+        ).all()
+
+        upsert_data = [
+            {
+                "id": ids[0],
+                "x": "x1",
+                "y": "y1",
+            },
+            {
+                "id": 32,
+                "x": "x19",
+                "y": "y7",
+            },
+            {
+                "id": ids[1],
+                "x": "x5",
+                "y": "y6",
+            },
+            {
+                "id": 28,
+                "x": "x9",
+                "y": "y15",
+            },
+        ]
+
+        stmt = provision.upsert(
+            config,
+            data,
+            (data,),
+            lambda inserted: {"x": inserted.x + " upserted"},
+        )
+
+        result = connection.execute(stmt, upsert_data)
+
+        eq_(
+            result.all(),
+            [
+                (ids[0], "x1 upserted", "y1", 4),
+                (32, "x19", "y7", 5),
+                (ids[1], "x5 upserted", "y2", 8),
+                (28, "x9", "y15", 5),
+            ],
+        )
+
     @testing.combinations(True, False, argnames="use_returning")
     @testing.combinations(1, 2, argnames="num_embedded_params")
     @testing.combinations(True, False, argnames="use_whereclause")
@@ -835,7 +949,11 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
     def test_insert_w_bindparam_in_subq(
         self, connection, use_returning, num_embedded_params, use_whereclause
     ):
-        """test #8639"""
+        """test #8639
+
+        see also test_insert_w_bindparam_in_nested_insert
+
+        """
 
         t = self.tables.data
         extra = self.tables.extra_table