]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
establish consistency for RETURNING column labels
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Nov 2022 23:40:03 +0000 (18:40 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 11 Nov 2022 19:48:49 +0000 (14:48 -0500)
For the PostgreSQL and SQL Server dialects only, adjusted the compiler so
that when rendering column expressions in the RETURNING clause, the "non
anon" label that's used in SELECT statements is suggested for SQL
expression elements that generate a label; the primary example is a SQL
function that may be emitting as part of the column's type, where the label
name should match the column's name by default. This restores a not-well
defined behavior that had changed in version 1.4.21 due to :ticket:`6718`,
:ticket:`6710`. The Oracle dialect has a different RETURNING implementation
and was not affected by this issue. Version 2.0 features an across the
board change for its widely expanded support of RETURNING on other
backends.

Fixed issue in the Oracle dialect where an INSERT statement that used
``insert(some_table).values(...).returning(some_table)`` against a full
:class:`.Table` object at once would fail to execute, raising an exception.

Fixes: #8770
Change-Id: I2ab078a214a778ffe1720dbd864ae4c105a0691d
(cherry picked from commit c8a7b67181d31634355150fc0379ec0e780ff728)

doc/build/changelog/unreleased_14/8770.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py
test/dialect/mssql/test_compiler.py
test/dialect/oracle/test_compiler.py
test/dialect/postgresql/test_compiler.py
test/sql/test_labels.py
test/sql/test_returning.py

diff --git a/doc/build/changelog/unreleased_14/8770.rst b/doc/build/changelog/unreleased_14/8770.rst
new file mode 100644 (file)
index 0000000..8968b03
--- /dev/null
@@ -0,0 +1,23 @@
+.. change::
+    :tags: bug, postgresql, mssql
+    :tickets: 8770
+
+    For the PostgreSQL and SQL Server dialects only, adjusted the compiler so
+    that when rendering column expressions in the RETURNING clause, the "non
+    anon" label that's used in SELECT statements is suggested for SQL
+    expression elements that generate a label; the primary example is a SQL
+    function that may be emitting as part of the column's type, where the label
+    name should match the column's name by default. This restores a not-well
+    defined behavior that had changed in version 1.4.21 due to :ticket:`6718`,
+    :ticket:`6710`. The Oracle dialect has a different RETURNING implementation
+    and was not affected by this issue. Version 2.0 features an across the
+    board change for its widely expanded support of RETURNING on other
+    backends.
+
+
+.. change::
+    :tags: bug, oracle
+
+    Fixed issue in the Oracle dialect where an INSERT statement that used
+    ``insert(some_table).values(...).returning(some_table)`` against a full
+    :class:`.Table` object at once would fail to execute, raising an exception.
index 0509413062fd466344aeec091eb59d0a443d1e4b..ea9c90a51ec556d86e45ad5b19e18c6489ccf4e6 100644 (file)
@@ -2154,6 +2154,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
                 stmt,
                 adapter.traverse(c),
                 {"result_map_targets": (c,)},
+                fallback_label_name=c._non_anon_label,
             )
             for c in expression._select_iterables(returning_cols)
         ]
index 90dabc83b932c29e1d594accc793b66885ff3884..fe18d1310b02f0888a754fa1cfd44a43878e3cbb 100644 (file)
@@ -468,6 +468,7 @@ from ... import processors
 from ... import types as sqltypes
 from ... import util
 from ...engine import cursor as _cursor
+from ...sql import expression
 from ...util import compat
 
 
@@ -887,11 +888,12 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
                 self.cursor,
                 [
                     (getattr(col, "name", col._anon_name_label), None)
-                    for col in self.compiled.returning
+                    for col in expression._select_iterables(
+                        self.compiled.returning
+                    )
                 ],
                 initial_buffer=[tuple(returning_params)],
             )
-
             self.cursor_fetch_strategy = fetch_strategy
 
     def create_cursor(self):
index c94c7732545dbd8e6c2291068d82653a541339ed..b980183d007bd2f764888423bfc800759cdcd108 100644 (file)
@@ -2515,7 +2515,9 @@ class PGCompiler(compiler.SQLCompiler):
     def returning_clause(self, stmt, returning_cols):
 
         columns = [
-            self._label_returning_column(stmt, c)
+            self._label_returning_column(
+                stmt, c, fallback_label_name=c._non_anon_label
+            )
             for c in expression._select_iterables(returning_cols)
         ]
 
index 0e441fbec8e6c39a78551fc7885c335342c95c57..a7232f096d64b21b514178765af4a2b993f67e47 100644 (file)
@@ -3047,7 +3047,9 @@ class SQLCompiler(Compiled):
             )
         self._result_columns.append((keyname, name, objects, type_))
 
-    def _label_returning_column(self, stmt, column, column_clause_args=None):
+    def _label_returning_column(
+        self, stmt, column, column_clause_args=None, **kw
+    ):
         """Render a column with necessary labels inside of a RETURNING clause.
 
         This method is provided for individual dialects in place of calling
@@ -3063,6 +3065,7 @@ class SQLCompiler(Compiled):
             True,
             False,
             {} if column_clause_args is None else column_clause_args,
+            **kw
         )
 
     def _label_select_column(
@@ -3127,7 +3130,6 @@ class SQLCompiler(Compiled):
             "_label_select_column is only relevant within "
             "the columns clause of a SELECT or RETURNING"
         )
-
         if isinstance(column, elements.Label):
             if col_expr is not column:
                 result_expr = _CompileLabel(
@@ -4319,7 +4321,9 @@ class StrSQLCompiler(SQLCompiler):
 
     def returning_clause(self, stmt, returning_cols):
         columns = [
-            self._label_select_column(None, c, True, False, {})
+            self._label_select_column(
+                None, c, True, False, {}, fallback_label_name=c._non_anon_label
+            )
             for c in base._select_iterables(returning_cols)
         ]
 
index 95e13f0810d0121e1f37f6c072f9b19582324c30..956f8ae8d8adc018096f6a6d49dcde05d130d3a9 100644 (file)
@@ -6191,7 +6191,7 @@ class Select(
             self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY)
         return self
 
-    def _generate_columns_plus_names(self, anon_for_dupe_key):
+    def _generate_columns_plus_names(self, anon_for_dupe_key, cols=None):
         """Generate column names as rendered in a SELECT statement by
         the compiler.
 
@@ -6201,7 +6201,9 @@ class Select(
         _column_naming_convention as well.
 
         """
-        cols = self._all_selected_columns
+
+        if cols is None:
+            cols = self._all_selected_columns
 
         key_naming_convention = SelectState._column_naming_convention(
             self._label_style
index bad5e4e10b6ec07416cc74920b00e9e662c66cf1..d54295b3062eb03505495856fef75eabb9bdc9e4 100644 (file)
@@ -37,6 +37,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing.assertions import eq_ignore_whitespace
+from sqlalchemy.types import TypeEngine
 
 tbl = table("t", column("a"))
 
@@ -104,6 +105,34 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "Latin1_General_CS_AS_KS_WS_CI ASC",
         )
 
+    @testing.fixture
+    def column_expression_fixture(self):
+        class MyString(TypeEngine):
+            def column_expression(self, column):
+                return func.lower(column)
+
+        return table(
+            "some_table", column("name", String), column("value", MyString)
+        )
+
+    @testing.combinations("columns", "table", argnames="use_columns")
+    def test_plain_returning_column_expression(
+        self, column_expression_fixture, use_columns
+    ):
+        """test #8770"""
+        table1 = column_expression_fixture
+
+        if use_columns == "columns":
+            stmt = insert(table1).returning(table1)
+        else:
+            stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO some_table (name, value) OUTPUT inserted.name, "
+            "lower(inserted.value) AS value VALUES (:name, :value)",
+        )
+
     def test_join_with_hint(self):
         t1 = table(
             "t1",
index 22ffc888ab088e2f6ba7387c60d2735f35d7db7a..8a8f51df0120dfbc9e9286319eb0264e30be506f 100644 (file)
@@ -8,6 +8,7 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Identity
 from sqlalchemy import Index
+from sqlalchemy import insert
 from sqlalchemy import Integer
 from sqlalchemy import literal
 from sqlalchemy import literal_column
@@ -39,6 +40,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import eq_ignore_whitespace
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
+from sqlalchemy.types import TypeEngine
 
 
 class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -1150,6 +1152,35 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "t1.c2, t1.c3 INTO :ret_0, :ret_1",
         )
 
+    @testing.fixture
+    def column_expression_fixture(self):
+        class MyString(TypeEngine):
+            def column_expression(self, column):
+                return func.lower(column)
+
+        return table(
+            "some_table", column("name", String), column("value", MyString)
+        )
+
+    @testing.combinations("columns", "table", argnames="use_columns")
+    def test_plain_returning_column_expression(
+        self, column_expression_fixture, use_columns
+    ):
+        """test #8770"""
+        table1 = column_expression_fixture
+
+        if use_columns == "columns":
+            stmt = insert(table1).returning(table1)
+        else:
+            stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+            "RETURNING some_table.name, lower(some_table.value) "
+            "INTO :ret_0, :ret_1",
+        )
+
     def test_returning_insert_computed(self):
         m = MetaData()
         t1 = Table(
index 897909b158bcc1f7143e2480a8bd1c6d44fcf7d7..0249c7952ce1aa91c3eaab3d43c995d3f199bd6b 100644 (file)
@@ -59,6 +59,7 @@ from sqlalchemy.testing.assertions import assert_raises_message
 from sqlalchemy.testing.assertions import AssertsCompiledSQL
 from sqlalchemy.testing.assertions import expect_warnings
 from sqlalchemy.testing.assertions import is_
+from sqlalchemy.types import TypeEngine
 from sqlalchemy.util import OrderedDict
 from sqlalchemy.util import u
 
@@ -205,6 +206,35 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=dialect,
         )
 
+    @testing.fixture
+    def column_expression_fixture(self):
+        class MyString(TypeEngine):
+            def column_expression(self, column):
+                return func.lower(column)
+
+        return table(
+            "some_table", column("name", String), column("value", MyString)
+        )
+
+    @testing.combinations("columns", "table", argnames="use_columns")
+    def test_plain_returning_column_expression(
+        self, column_expression_fixture, use_columns
+    ):
+        """test #8770"""
+        table1 = column_expression_fixture
+
+        if use_columns == "columns":
+            stmt = insert(table1).returning(table1)
+        else:
+            stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO some_table (name, value) "
+            "VALUES (%(name)s, %(value)s) RETURNING some_table.name, "
+            "lower(some_table.value) AS value",
+        )
+
     def test_create_drop_enum(self):
         # test escaping and unicode within CREATE TYPE for ENUM
         typ = postgresql.ENUM(
index a82b0372eaad6772b720eed26b3e8db0a6d4784f..869134f9c0c91ff88a327f7b0258ee44032aed02 100644 (file)
@@ -2,6 +2,8 @@ from sqlalchemy import bindparam
 from sqlalchemy import Boolean
 from sqlalchemy import cast
 from sqlalchemy import exc as exceptions
+from sqlalchemy import func
+from sqlalchemy import insert
 from sqlalchemy import Integer
 from sqlalchemy import literal_column
 from sqlalchemy import MetaData
@@ -32,6 +34,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
+from sqlalchemy.types import TypeEngine
 
 IDENT_LENGTH = 29
 
@@ -802,7 +805,7 @@ class ColExprLabelTest(fixtures.TestBase, AssertsCompiledSQL):
 
     """
 
-    __dialect__ = "default"
+    __dialect__ = "default_enhanced"
 
     table1 = table("some_table", column("name"), column("value"))
 
@@ -827,6 +830,101 @@ class ColExprLabelTest(fixtures.TestBase, AssertsCompiledSQL):
 
         return SomeColThing
 
+    @testing.fixture
+    def compiler_column_fixture(self):
+        return self._fixture()
+
+    @testing.fixture
+    def column_expression_fixture(self):
+        class MyString(TypeEngine):
+            def column_expression(self, column):
+                return func.lower(column)
+
+        return table(
+            "some_table", column("name", String), column("value", MyString)
+        )
+
+    def test_plain_select_compiler_expression(self, compiler_column_fixture):
+        expr = compiler_column_fixture
+        table1 = self.table1
+
+        self.assert_compile(
+            select(
+                table1.c.name,
+                expr(table1.c.value),
+            ),
+            "SELECT some_table.name, SOME_COL_THING(some_table.value) "
+            "AS value FROM some_table",
+        )
+
+    def test_plain_select_column_expression(self, column_expression_fixture):
+        table1 = column_expression_fixture
+
+        self.assert_compile(
+            select(table1),
+            "SELECT some_table.name, lower(some_table.value) AS value "
+            "FROM some_table",
+        )
+
+    def test_plain_returning_compiler_expression(
+        self, compiler_column_fixture
+    ):
+        expr = compiler_column_fixture
+        table1 = self.table1
+
+        self.assert_compile(
+            insert(table1).returning(
+                table1.c.name,
+                expr(table1.c.value),
+            ),
+            "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+            "RETURNING some_table.name, "
+            "SOME_COL_THING(some_table.value) AS value",
+        )
+
+    @testing.combinations("columns", "table", argnames="use_columns")
+    def test_plain_returning_column_expression(
+        self, column_expression_fixture, use_columns
+    ):
+        table1 = column_expression_fixture
+
+        if use_columns == "columns":
+            stmt = insert(table1).returning(table1)
+        else:
+            stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+            "RETURNING some_table.name, lower(some_table.value) AS value",
+        )
+
+    def test_select_dupes_column_expression(self, column_expression_fixture):
+        table1 = column_expression_fixture
+
+        self.assert_compile(
+            select(table1.c.name, table1.c.value, table1.c.value),
+            "SELECT some_table.name, lower(some_table.value) AS value, "
+            "lower(some_table.value) AS value__1 FROM some_table",
+        )
+
+    def test_returning_dupes_column_expression(
+        self, column_expression_fixture
+    ):
+        table1 = column_expression_fixture
+
+        stmt = insert(table1).returning(
+            table1.c.name, table1.c.value, table1.c.value
+        )
+
+        # 1.4 behavior only; limited support for labels in RETURNING
+        self.assert_compile(
+            stmt,
+            "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+            "RETURNING some_table.name, lower(some_table.value) AS value, "
+            "lower(some_table.value) AS value",
+        )
+
     def test_column_auto_label_dupes_label_style_none(self):
         expr = self._fixture()
         table1 = self.table1
@@ -991,6 +1089,7 @@ class ColExprLabelTest(fixtures.TestBase, AssertsCompiledSQL):
             # not sure if this SQL is right but this is what it was
             # before the new labeling, just different label name
             "SELECT value = 0 AS value, value",
+            dialect="default",
         )
 
     def test_label_auto_label_use_labels(self):
index 10bf3beb6fe9b44576c79bd338d09076148529e1..2db9b8bc9de663936cc0394cd3b745894e1ca93f 100644 (file)
@@ -350,6 +350,50 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults):
             "inserted_primary_key",
         )
 
+    @testing.fixture
+    def column_expression_fixture(self, metadata, connection):
+        class MyString(TypeDecorator):
+            cache_ok = True
+            impl = String(50)
+
+            def column_expression(self, column):
+                return func.lower(column)
+
+        t1 = Table(
+            "some_table",
+            metadata,
+            Column("name", String(50)),
+            Column("value", MyString(50)),
+        )
+        metadata.create_all(connection)
+        return t1
+
+    @testing.combinations("columns", "table", argnames="use_columns")
+    def test_plain_returning_column_expression(
+        self, column_expression_fixture, use_columns, connection
+    ):
+        """test #8770"""
+        table1 = column_expression_fixture
+
+        if use_columns == "columns":
+            stmt = (
+                insert(table1)
+                .values(name="n1", value="ValUE1")
+                .returning(table1)
+            )
+        else:
+            stmt = (
+                insert(table1)
+                .values(name="n1", value="ValUE1")
+                .returning(table1.c.name, table1.c.value)
+            )
+
+        result = connection.execute(stmt)
+        row = result.first()
+
+        eq_(row._mapping["name"], "n1")
+        eq_(row._mapping["value"], "value1")
+
     @testing.fails_on_everything_except("postgresql", "firebird")
     def test_literal_returning(self, connection):
         if testing.against("postgresql"):