]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
render INSERT/UPDATE column expressions up front; pass state
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Aug 2020 17:03:17 +0000 (13:03 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Aug 2020 17:34:27 +0000 (13:34 -0400)
Fixes related to rendering of complex UPDATE DML
which was not correctly preserving positional parameter
order in conjunction with DML features that are only known
to work on the PostgreSQL database.    Both pg8000
and asyncpg use positional parameters which is why these
issues are suddenly apparent.

crud.py now takes on the task of rendering the column
expressions for SET or VALUES so that for the very unusual
case that the column expression is a compound expression
that includes a bound parameter (namely an array index),
the bound parameter order is preserved.

Additionally, crud.py passes through the positional_names
keyword argument into bindparam_string() which is necessary
when CTEs are being rendered, as PG supports complex
CTE / INSERT / UPDATE scenarios.

Change-Id: I7f03920500e19b721636b84594de78a5bfdcbc82

lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/dml.py
test/sql/test_update.py

index 8e273f67c47fdbadafee42052a3de3719dcd4824..542bf58ac503e214a59d94b158a3da1a65451035 100644 (file)
@@ -3286,7 +3286,7 @@ class SQLCompiler(Compiled):
 
         if crud_params_single or not supports_default_values:
             text += " (%s)" % ", ".join(
-                [preparer.format_column(c[0]) for c in crud_params_single]
+                [expr for c, expr, value in crud_params_single]
             )
 
         if self.returning or insert_stmt._returning:
@@ -3311,12 +3311,15 @@ class SQLCompiler(Compiled):
         elif compile_state._has_multi_parameters:
             text += " VALUES %s" % (
                 ", ".join(
-                    "(%s)" % (", ".join(c[1] for c in crud_param_set))
+                    "(%s)"
+                    % (", ".join(value for c, expr, value in crud_param_set))
                     for crud_param_set in crud_params
                 )
             )
         else:
-            insert_single_values_expr = ", ".join([c[1] for c in crud_params])
+            insert_single_values_expr = ", ".join(
+                [value for c, expr, value in crud_params]
+            )
             text += " VALUES (%s)" % insert_single_values_expr
             if toplevel:
                 self.insert_single_values_expr = insert_single_values_expr
@@ -3424,15 +3427,7 @@ class SQLCompiler(Compiled):
         text += table_text
 
         text += " SET "
-        include_table = (
-            is_multitable and self.render_table_with_column_in_update_from
-        )
-        text += ", ".join(
-            c[0]._compiler_dispatch(self, include_table=include_table)
-            + "="
-            + c[1]
-            for c in crud_params
-        )
+        text += ", ".join(expr + "=" + value for c, expr, value in crud_params)
 
         if self.returning or update_stmt._returning:
             if self.returning_precedes_values:
index 85112f8506aaad2a5a46cfdcfe0cdbcc6ac683c0..7d0616da71f28733ccef4ec4ba1e35cb208343e4 100644 (file)
@@ -65,7 +65,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
     # compiled params - return binds for all columns
     if compiler.column_keys is None and compile_state._no_parameters:
         return [
-            (c, _create_bind_param(compiler, c, None, required=True))
+            (
+                c,
+                compiler.preparer.format_column(c),
+                _create_bind_param(compiler, c, None, required=True),
+            )
             for c in stmt.table.columns
         ]
 
@@ -90,18 +94,20 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
 
     if stmt_parameters is not None:
         _get_stmt_parameters_params(
-            compiler, parameters, stmt_parameters, _column_as_key, values, kw
+            compiler,
+            compile_state,
+            parameters,
+            stmt_parameters,
+            _column_as_key,
+            values,
+            kw,
         )
 
     check_columns = {}
 
     # special logic that only occurs for multi-table UPDATE
     # statements
-    if (
-        compile_state.isupdate
-        and compile_state._extra_froms
-        and stmt_parameters
-    ):
+    if compile_state.isupdate and compile_state.is_multitable:
         _get_multitable_params(
             compiler,
             stmt,
@@ -162,7 +168,13 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
         # into INSERT (firstcol) VALUES (DEFAULT) which can be turned
         # into an in-place multi values.  This supports
         # insert_executemany_returning mode :)
-        values = [(stmt.table.columns[0], "DEFAULT")]
+        values = [
+            (
+                stmt.table.columns[0],
+                compiler.preparer.format_column(stmt.table.columns[0]),
+                "DEFAULT",
+            )
+        ]
 
     return values
 
@@ -286,7 +298,7 @@ def _scan_insert_from_select_cols(
         col_key = _getattr_col_key(c)
         if col_key in parameters and col_key not in check_columns:
             parameters.pop(col_key)
-            values.append((c, None))
+            values.append((c, compiler.preparer.format_column(c), None))
         else:
             _append_param_insert_select_hasdefault(
                 compiler, stmt, c, add_select_cols, kw
@@ -297,7 +309,7 @@ def _scan_insert_from_select_cols(
         compiler._insert_from_select = compiler._insert_from_select._generate()
         compiler._insert_from_select._raw_columns = tuple(
             compiler._insert_from_select._raw_columns
-        ) + tuple(expr for col, expr in add_select_cols)
+        ) + tuple(expr for col, col_expr, expr in add_select_cols)
 
 
 def _scan_cols(
@@ -390,7 +402,13 @@ def _scan_cols(
 
         elif compile_state.isupdate:
             _append_param_update(
-                compiler, stmt, c, implicit_return_defaults, values, kw
+                compiler,
+                compile_state,
+                stmt,
+                c,
+                implicit_return_defaults,
+                values,
+                kw,
             )
 
 
@@ -410,6 +428,10 @@ def _append_param_parameter(
 
     value = parameters.pop(col_key)
 
+    col_value = compiler.preparer.format_column(
+        c, use_table=compile_state.include_table_with_column_exprs
+    )
+
     if coercions._is_literal(value):
         value = _create_bind_param(
             compiler,
@@ -446,7 +468,7 @@ def _append_param_parameter(
             if not c.primary_key:
                 compiler.postfetch.append(c)
             value = compiler.process(value.self_group(), **kw)
-    values.append((c, value))
+    values.append((c, col_value, value))
 
 
 def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
@@ -472,16 +494,31 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
                 not c.default.optional
                 or not compiler.dialect.sequences_optional
             ):
-                proc = compiler.process(c.default, **kw)
-                values.append((c, proc))
+                values.append(
+                    (
+                        c,
+                        compiler.preparer.format_column(c),
+                        compiler.process(c.default, **kw),
+                    )
+                )
             compiler.returning.append(c)
         elif c.default.is_clause_element:
             values.append(
-                (c, compiler.process(c.default.arg.self_group(), **kw))
+                (
+                    c,
+                    compiler.preparer.format_column(c),
+                    compiler.process(c.default.arg.self_group(), **kw),
+                )
             )
             compiler.returning.append(c)
         else:
-            values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
+            values.append(
+                (
+                    c,
+                    compiler.preparer.format_column(c),
+                    _create_insert_prefetch_bind_param(compiler, c, **kw),
+                )
+            )
     elif c is stmt.table._autoincrement_column or c.server_default is not None:
         compiler.returning.append(c)
     elif not c.nullable:
@@ -490,14 +527,22 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
         _warn_pk_with_no_anticipated_value(c)
 
 
-def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None):
-    param = _create_bind_param(compiler, c, None, process=process, name=name)
+def _create_insert_prefetch_bind_param(
+    compiler, c, process=True, name=None, **kw
+):
+    param = _create_bind_param(
+        compiler, c, None, process=process, name=name, **kw
+    )
     compiler.insert_prefetch.append(c)
     return param
 
 
-def _create_update_prefetch_bind_param(compiler, c, process=True, name=None):
-    param = _create_bind_param(compiler, c, None, process=process, name=name)
+def _create_update_prefetch_bind_param(
+    compiler, c, process=True, name=None, **kw
+):
+    param = _create_bind_param(
+        compiler, c, None, process=process, name=name, **kw
+    )
     compiler.update_prefetch.append(c)
     return param
 
@@ -539,9 +584,9 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
     else:
         col = _multiparam_column(c, index)
         if isinstance(stmt, dml.Insert):
-            return _create_insert_prefetch_bind_param(compiler, col)
+            return _create_insert_prefetch_bind_param(compiler, col, **kw)
         else:
-            return _create_update_prefetch_bind_param(compiler, col)
+            return _create_update_prefetch_bind_param(compiler, col, **kw)
 
 
 def _append_param_insert_pk(compiler, stmt, c, values, kw):
@@ -582,7 +627,13 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
             or compiler.dialect.preexecute_autoincrement_sequences
         )
     ):
-        values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
+        values.append(
+            (
+                c,
+                compiler.preparer.format_column(c),
+                _create_insert_prefetch_bind_param(compiler, c, **kw),
+            )
+        )
     elif c.default is None and c.server_default is None and not c.nullable:
         # no .default, no .server_default, not autoincrement, we have
         # no indication this primary key column will have any value
@@ -597,15 +648,25 @@ def _append_param_insert_hasdefault(
         if compiler.dialect.supports_sequences and (
             not c.default.optional or not compiler.dialect.sequences_optional
         ):
-            proc = compiler.process(c.default, **kw)
-            values.append((c, proc))
+            values.append(
+                (
+                    c,
+                    compiler.preparer.format_column(c),
+                    compiler.process(c.default, **kw),
+                )
+            )
             if implicit_return_defaults and c in implicit_return_defaults:
                 compiler.returning.append(c)
             elif not c.primary_key:
                 compiler.postfetch.append(c)
     elif c.default.is_clause_element:
-        proc = compiler.process(c.default.arg.self_group(), **kw)
-        values.append((c, proc))
+        values.append(
+            (
+                c,
+                compiler.preparer.format_column(c),
+                compiler.process(c.default.arg.self_group(), **kw),
+            )
+        )
 
         if implicit_return_defaults and c in implicit_return_defaults:
             compiler.returning.append(c)
@@ -613,7 +674,13 @@ def _append_param_insert_hasdefault(
             # don't add primary key column to postfetch
             compiler.postfetch.append(c)
     else:
-        values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
+        values.append(
+            (
+                c,
+                compiler.preparer.format_column(c),
+                _create_insert_prefetch_bind_param(compiler, c, **kw),
+            )
+        )
 
 
 def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
@@ -622,32 +689,55 @@ def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
         if compiler.dialect.supports_sequences and (
             not c.default.optional or not compiler.dialect.sequences_optional
         ):
-            proc = c.default
-            values.append((c, proc.next_value()))
+            values.append(
+                (c, compiler.preparer.format_column(c), c.default.next_value())
+            )
     elif c.default.is_clause_element:
-        proc = c.default.arg.self_group()
-        values.append((c, proc))
+        values.append(
+            (c, compiler.preparer.format_column(c), c.default.arg.self_group())
+        )
     else:
         values.append(
-            (c, _create_insert_prefetch_bind_param(compiler, c, process=False))
+            (
+                c,
+                compiler.preparer.format_column(c),
+                _create_insert_prefetch_bind_param(
+                    compiler, c, process=False, **kw
+                ),
+            )
         )
 
 
 def _append_param_update(
-    compiler, stmt, c, implicit_return_defaults, values, kw
+    compiler, compile_state, stmt, c, implicit_return_defaults, values, kw
 ):
 
+    include_table = compile_state.include_table_with_column_exprs
     if c.onupdate is not None and not c.onupdate.is_sequence:
         if c.onupdate.is_clause_element:
             values.append(
-                (c, compiler.process(c.onupdate.arg.self_group(), **kw))
+                (
+                    c,
+                    compiler.preparer.format_column(
+                        c, use_table=include_table,
+                    ),
+                    compiler.process(c.onupdate.arg.self_group(), **kw),
+                )
             )
             if implicit_return_defaults and c in implicit_return_defaults:
                 compiler.returning.append(c)
             else:
                 compiler.postfetch.append(c)
         else:
-            values.append((c, _create_update_prefetch_bind_param(compiler, c)))
+            values.append(
+                (
+                    c,
+                    compiler.preparer.format_column(
+                        c, use_table=include_table,
+                    ),
+                    _create_update_prefetch_bind_param(compiler, c, **kw),
+                )
+            )
     elif c.server_onupdate is not None:
         if implicit_return_defaults and c in implicit_return_defaults:
             compiler.returning.append(c)
@@ -676,6 +766,9 @@ def _get_multitable_params(
         (coercions.expect(roles.DMLColumnRole, c), param)
         for c, param in stmt_parameters.items()
     )
+
+    include_table = compile_state.include_table_with_column_exprs
+
     affected_tables = set()
     for t in compile_state._extra_froms:
         for c in t.c:
@@ -683,6 +776,8 @@ def _get_multitable_params(
                 affected_tables.add(t)
                 check_columns[_getattr_col_key(c)] = c
                 value = normalized_params[c]
+
+                col_value = compiler.process(c, include_table=include_table)
                 if coercions._is_literal(value):
                     value = _create_bind_param(
                         compiler,
@@ -699,7 +794,7 @@ def _get_multitable_params(
                 else:
                     compiler.postfetch.append(c)
                     value = compiler.process(value.self_group(), **kw)
-                values.append((c, value))
+                values.append((c, col_value, value))
     # determine tables which are actually to be updated - process onupdate
     # and server_onupdate for these
     for t in affected_tables:
@@ -711,6 +806,7 @@ def _get_multitable_params(
                     values.append(
                         (
                             c,
+                            compiler.process(c, include_table=include_table),
                             compiler.process(
                                 c.onupdate.arg.self_group(), **kw
                             ),
@@ -721,8 +817,9 @@ def _get_multitable_params(
                     values.append(
                         (
                             c,
+                            compiler.process(c, include_table=include_table),
                             _create_update_prefetch_bind_param(
-                                compiler, c, name=_col_bind_name(c)
+                                compiler, c, name=_col_bind_name(c), **kw
                             ),
                         )
                     )
@@ -736,7 +833,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
 
     for i, row in enumerate(compile_state._multi_parameters[1:]):
         extension = []
-        for (col, param) in values_0:
+        for (col, col_expr, param) in values_0:
             if col in row or col.key in row:
                 key = col if col in row else col.key
 
@@ -755,7 +852,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
                     compiler, stmt, col, i, kw
                 )
 
-            extension.append((col, new_param))
+            extension.append((col, col_expr, new_param))
 
         values.append(extension)
 
@@ -763,8 +860,15 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
 
 
 def _get_stmt_parameters_params(
-    compiler, parameters, stmt_parameters, _column_as_key, values, kw
+    compiler,
+    compile_state,
+    parameters,
+    stmt_parameters,
+    _column_as_key,
+    values,
+    kw,
 ):
+
     for k, v in stmt_parameters.items():
         colkey = _column_as_key(k)
         if colkey is not None:
@@ -773,6 +877,11 @@ def _get_stmt_parameters_params(
             # a non-Column expression on the left side;
             # add it to values() in an "as-is" state,
             # coercing right side to bound param
+
+            col_expr = compiler.process(
+                k, include_table=compile_state.include_table_with_column_exprs
+            )
+
             if coercions._is_literal(v):
                 v = compiler.process(
                     elements.BindParameter(None, v, type_=k.type), **kw
@@ -780,7 +889,7 @@ def _get_stmt_parameters_params(
             else:
                 v = compiler.process(v.self_group(), **kw)
 
-            values.append((k, v))
+            values.append((k, col_expr, v))
 
 
 def _get_returning_modifiers(compiler, stmt, compile_state):
index 21476c1f91e8e928cfef25ba0d30df20c7193a97..9308658984ba784308ea082135b09d52982f2183 100644 (file)
@@ -133,6 +133,8 @@ class DMLState(CompileState):
 class InsertDMLState(DMLState):
     isinsert = True
 
+    include_table_with_column_exprs = False
+
     def __init__(self, statement, compiler, **kw):
         self.statement = statement
 
@@ -149,6 +151,8 @@ class InsertDMLState(DMLState):
 class UpdateDMLState(DMLState):
     isupdate = True
 
+    include_table_with_column_exprs = False
+
     def __init__(self, statement, compiler, **kw):
         self.statement = statement
         self.isupdate = True
@@ -159,7 +163,11 @@ class UpdateDMLState(DMLState):
             self._process_values(statement)
         elif statement._multi_values:
             self._process_multi_values(statement)
-        self._extra_froms = self._make_extra_froms(statement)
+        self._extra_froms = ef = self._make_extra_froms(statement)
+        self.is_multitable = mt = ef and self._dict_parameters
+        self.include_table_with_column_exprs = (
+            mt and compiler.render_table_with_column_in_update_from
+        )
 
 
 @CompileState.plugin_for("default", "delete")
index 18e9da654fb4e2dab9febb47b55214d9588c0c7a..5db5fed11cf242ac115629266293f968044ab7c4 100644 (file)
@@ -840,7 +840,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             dialect=mysql.dialect(),
         )
 
-    def test_update_to_expression(self):
+    def test_update_to_expression_one(self):
         """test update from an expression.
 
         this logic is triggered currently by a left side that doesn't
@@ -856,6 +856,72 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             "UPDATE mytable SET foo(myid)=:param_1",
         )
 
+    def test_update_to_expression_two(self):
+        """test update from an expression.
+
+        this logic is triggered currently by a left side that doesn't
+        have a key.  The current supported use case is updating the index
+        of a PostgreSQL ARRAY type.
+
+        """
+
+        from sqlalchemy import ARRAY
+
+        t = table(
+            "foo",
+            column("data1", ARRAY(Integer)),
+            column("data2", ARRAY(Integer)),
+        )
+
+        stmt = t.update().values({t.c.data1[5]: 7, t.c.data2[10]: 18})
+        dialect = default.StrCompileDialect()
+        dialect.paramstyle = "qmark"
+        dialect.positional = True
+        self.assert_compile(
+            stmt,
+            "UPDATE foo SET data1[?]=?, data2[?]=?",
+            dialect=dialect,
+            checkpositional=(5, 7, 10, 18),
+        )
+
+    def test_update_to_expression_three(self):
+        # this test is from test_defaults but exercises a particular
+        # parameter ordering issue
+        metadata = MetaData()
+
+        q = Table(
+            "q",
+            metadata,
+            Column("x", Integer, default=2),
+            Column("y", Integer, onupdate=5),
+            Column("z", Integer),
+        )
+
+        p = Table(
+            "p",
+            metadata,
+            Column("s", Integer),
+            Column("t", Integer),
+            Column("u", Integer, onupdate=1),
+        )
+
+        cte = (
+            q.update().where(q.c.z == 1).values(x=7).returning(q.c.z).cte("c")
+        )
+        stmt = select([p.c.s, cte.c.z]).where(p.c.s == cte.c.z)
+
+        dialect = default.StrCompileDialect()
+        dialect.paramstyle = "qmark"
+        dialect.positional = True
+
+        self.assert_compile(
+            stmt,
+            "WITH c AS (UPDATE q SET x=?, y=? WHERE q.z = ? RETURNING q.z) "
+            "SELECT p.s, c.z FROM p, c WHERE p.s = c.z",
+            checkpositional=(7, None, 1),
+            dialect=dialect,
+        )
+
     def test_update_bound_ordering(self):
         """test that bound parameters between the UPDATE and FROM clauses
         order correctly in different SQL compilation scenarios.