]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
establish consistency for RETURNING column labels
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Nov 2022 23:40:03 +0000 (18:40 -0500)
committermike bayer <mike_mp@zzzcomputing.com>
Fri, 11 Nov 2022 16:20:00 +0000 (16:20 +0000)
The RETURNING clause now renders columns using the routine as that of the
:class:`.Select` to generate labels, which will include disambiguating
labels, as well as that a SQL function surrounding a named column will be
labeled using the column name itself. This is a more comprehensive change
than a similar one made for the 1.4 series that adjusted the function label
issue only.

includes 1.4's changelog for the backported version which also
fixes an Oracle issue independently of the 2.0 series.

Fixes: #8770
Change-Id: I2ab078a214a778ffe1720dbd864ae4c105a0691d

doc/build/changelog/unreleased_14/8770.rst [new file with mode: 0644]
doc/build/changelog/unreleased_20/8770.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/selectable.py
test/dialect/mssql/test_compiler.py
test/dialect/oracle/test_compiler.py
test/dialect/postgresql/test_compiler.py
test/sql/test_labels.py
test/sql/test_returning.py

diff --git a/doc/build/changelog/unreleased_14/8770.rst b/doc/build/changelog/unreleased_14/8770.rst
new file mode 100644 (file)
index 0000000..8968b03
--- /dev/null
@@ -0,0 +1,23 @@
+.. change::
+    :tags: bug, postgresql, mssql
+    :tickets: 8770
+
+    For the PostgreSQL and SQL Server dialects only, adjusted the compiler so
+    that when rendering column expressions in the RETURNING clause, the "non
+    anon" label that's used in SELECT statements is suggested for SQL
+    expression elements that generate a label; the primary example is a SQL
+    function that may be emitting as part of the column's type, where the label
+    name should match the column's name by default. This restores a not-well
+    defined behavior that had changed in version 1.4.21 due to :ticket:`6718`,
+    :ticket:`6710`. The Oracle dialect has a different RETURNING implementation
+    and was not affected by this issue. Version 2.0 features an across the
+    board change for its widely expanded support of RETURNING on other
+    backends.
+
+
+.. change::
+    :tags: bug, oracle
+
+    Fixed issue in the Oracle dialect where an INSERT statement that used
+    ``insert(some_table).values(...).returning(some_table)`` against a full
+    :class:`.Table` object at once would fail to execute, raising an exception.
diff --git a/doc/build/changelog/unreleased_20/8770.rst b/doc/build/changelog/unreleased_20/8770.rst
new file mode 100644 (file)
index 0000000..59b94d6
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8770
+
+    The RETURNING clause now renders columns using the routine as that of the
+    :class:`.Select` to generate labels, which will include disambiguating
+    labels, as well as that a SQL function surrounding a named column will be
+    labeled using the column name itself. This is a more comprehensive change
+    than a similar one made for the 1.4 series that adjusted the function label
+    issue only.
index a338ba27af136838612155f2cd60ddfd3c02665f..53fe96c9ae381093b2fa820fef683563ff4ae57a 100644 (file)
@@ -2295,11 +2295,24 @@ class MSSQLCompiler(compiler.SQLCompiler):
         columns = [
             self._label_returning_column(
                 stmt,
-                adapter.traverse(c),
+                adapter.traverse(column),
                 populate_result_map,
-                {"result_map_targets": (c,)},
+                {"result_map_targets": (column,)},
+                fallback_label_name=fallback_label_name,
+                column_is_repeated=repeated,
+                name=name,
+                proxy_name=proxy_name,
+                **kw,
+            )
+            for (
+                name,
+                proxy_name,
+                fallback_label_name,
+                column,
+                repeated,
+            ) in stmt._generate_columns_plus_names(
+                True, cols=expression._select_iterables(returning_cols)
             )
-            for c in expression._select_iterables(returning_cols)
         ]
 
         return "OUTPUT " + ", ".join(columns)
index 3e62cb3505d0fbfc1381d040d5c18b5d626b1706..97397e9cf4337bf6f1d385a607246ee6ec29f814 100644 (file)
@@ -3760,7 +3760,6 @@ class SQLCompiler(Compiled):
             "_label_select_column is only relevant within "
             "the columns clause of a SELECT or RETURNING"
         )
-
         if isinstance(column, elements.Label):
             if col_expr is not column:
                 result_expr = _CompileLabel(
@@ -4416,9 +4415,27 @@ class SQLCompiler(Compiled):
         populate_result_map: bool,
         **kw: Any,
     ) -> str:
+
         columns = [
-            self._label_returning_column(stmt, c, populate_result_map, **kw)
-            for c in base._select_iterables(returning_cols)
+            self._label_returning_column(
+                stmt,
+                column,
+                populate_result_map,
+                fallback_label_name=fallback_label_name,
+                column_is_repeated=repeated,
+                name=name,
+                proxy_name=proxy_name,
+                **kw,
+            )
+            for (
+                name,
+                proxy_name,
+                fallback_label_name,
+                column,
+                repeated,
+            ) in stmt._generate_columns_plus_names(
+                True, cols=base._select_iterables(returning_cols)
+            )
         ]
 
         return "RETURNING " + ", ".join(columns)
index 5145a4a16af43af528e5421c507156106699a4ad..2d3e3598b8d36f7e246e9d5feabfbfb69ee980a1 100644 (file)
@@ -59,6 +59,7 @@ from .selectable import FromClause
 from .selectable import HasCTE
 from .selectable import HasPrefixes
 from .selectable import Join
+from .selectable import SelectLabelStyle
 from .selectable import TableClause
 from .selectable import TypedReturnsRows
 from .sqltypes import NullType
@@ -399,6 +400,9 @@ class UpdateBase(
     ] = util.EMPTY_DICT
     named_with_column = False
 
+    _label_style: SelectLabelStyle = (
+        SelectLabelStyle.LABEL_STYLE_DISAMBIGUATE_ONLY
+    )
     table: _DMLTableElement
 
     _return_defaults = False
index 9de015774df3dddab73dfb4b6f956c1164046e54..488dfe721c3f0d400601e25b4d21d4b2989e4a1c 100644 (file)
@@ -2193,7 +2193,9 @@ class SelectsRows(ReturnsRows):
     _label_style: SelectLabelStyle = LABEL_STYLE_NONE
 
     def _generate_columns_plus_names(
-        self, anon_for_dupe_key: bool
+        self,
+        anon_for_dupe_key: bool,
+        cols: Optional[_SelectIterable] = None,
     ) -> List[_ColumnsPlusNames]:
         """Generate column names as rendered in a SELECT statement by
         the compiler.
@@ -2204,7 +2206,9 @@ class SelectsRows(ReturnsRows):
         _column_naming_convention as well.
 
         """
-        cols = self._all_selected_columns
+
+        if cols is None:
+            cols = self._all_selected_columns
 
         key_naming_convention = SelectState._column_naming_convention(
             self._label_style
index 8605ea9c0526c4e3cb7b3ddae7187854642312d0..b575595ac29367949d0d12de8e2d689750f4218f 100644 (file)
@@ -36,6 +36,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing.assertions import eq_ignore_whitespace
+from sqlalchemy.types import TypeEngine
 
 tbl = table("t", column("a"))
 
@@ -119,6 +120,34 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "Latin1_General_CS_AS_KS_WS_CI ASC",
         )
 
+    @testing.fixture
+    def column_expression_fixture(self):
+        class MyString(TypeEngine):
+            def column_expression(self, column):
+                return func.lower(column)
+
+        return table(
+            "some_table", column("name", String), column("value", MyString)
+        )
+
+    @testing.combinations("columns", "table", argnames="use_columns")
+    def test_plain_returning_column_expression(
+        self, column_expression_fixture, use_columns
+    ):
+        """test #8770"""
+        table1 = column_expression_fixture
+
+        if use_columns == "columns":
+            stmt = insert(table1).returning(table1)
+        else:
+            stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO some_table (name, value) OUTPUT inserted.name, "
+            "lower(inserted.value) AS value VALUES (:name, :value)",
+        )
+
     def test_join_with_hint(self):
         t1 = table(
             "t1",
index 2973c6e39d4d13d502464fb06a8227fa526f7d5c..8981e74e8c0f9ba0363d4876523091e1f168c9a7 100644 (file)
@@ -9,6 +9,7 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Identity
 from sqlalchemy import Index
+from sqlalchemy import insert
 from sqlalchemy import Integer
 from sqlalchemy import literal
 from sqlalchemy import literal_column
@@ -42,6 +43,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import eq_ignore_whitespace
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
+from sqlalchemy.types import TypeEngine
 
 
 class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -1359,6 +1361,35 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "t1.c2, t1.c3 INTO :ret_0, :ret_1",
         )
 
+    @testing.fixture
+    def column_expression_fixture(self):
+        class MyString(TypeEngine):
+            def column_expression(self, column):
+                return func.lower(column)
+
+        return table(
+            "some_table", column("name", String), column("value", MyString)
+        )
+
+    @testing.combinations("columns", "table", argnames="use_columns")
+    def test_plain_returning_column_expression(
+        self, column_expression_fixture, use_columns
+    ):
+        """test #8770"""
+        table1 = column_expression_fixture
+
+        if use_columns == "columns":
+            stmt = insert(table1).returning(table1)
+        else:
+            stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+            "RETURNING some_table.name, lower(some_table.value) "
+            "INTO :ret_0, :ret_1",
+        )
+
     def test_returning_insert_computed(self):
         m = MetaData()
         t1 = Table(
index 96a8e7d5a585e9b729f609d921bf70160d8b0e11..338d0da4ea40d8e3ee0649f25d457719e4147dd7 100644 (file)
@@ -61,6 +61,7 @@ from sqlalchemy.testing.assertions import AssertsCompiledSQL
 from sqlalchemy.testing.assertions import eq_ignore_whitespace
 from sqlalchemy.testing.assertions import expect_warnings
 from sqlalchemy.testing.assertions import is_
+from sqlalchemy.types import TypeEngine
 from sqlalchemy.util import OrderedDict
 
 
@@ -200,6 +201,35 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=dialect,
         )
 
+    @testing.fixture
+    def column_expression_fixture(self):
+        class MyString(TypeEngine):
+            def column_expression(self, column):
+                return func.lower(column)
+
+        return table(
+            "some_table", column("name", String), column("value", MyString)
+        )
+
+    @testing.combinations("columns", "table", argnames="use_columns")
+    def test_plain_returning_column_expression(
+        self, column_expression_fixture, use_columns
+    ):
+        """test #8770"""
+        table1 = column_expression_fixture
+
+        if use_columns == "columns":
+            stmt = insert(table1).returning(table1)
+        else:
+            stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO some_table (name, value) "
+            "VALUES (%(name)s, %(value)s) RETURNING some_table.name, "
+            "lower(some_table.value) AS value",
+        )
+
     def test_create_drop_enum(self):
         # test escaping and unicode within CREATE TYPE for ENUM
         typ = postgresql.ENUM("val1", "val2", "val's 3", "méil", name="myname")
index 42d9c5f0039b1546a3ed713df3eca562c6170cea..a74c5811c324fe175818295f6afe3cc85355ec29 100644 (file)
@@ -2,6 +2,8 @@ from sqlalchemy import bindparam
 from sqlalchemy import Boolean
 from sqlalchemy import cast
 from sqlalchemy import exc as exceptions
+from sqlalchemy import func
+from sqlalchemy import insert
 from sqlalchemy import Integer
 from sqlalchemy import literal_column
 from sqlalchemy import MetaData
@@ -32,6 +34,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
+from sqlalchemy.types import TypeEngine
 
 IDENT_LENGTH = 29
 
@@ -827,6 +830,100 @@ class ColExprLabelTest(fixtures.TestBase, AssertsCompiledSQL):
 
         return SomeColThing
 
+    @testing.fixture
+    def compiler_column_fixture(self):
+        return self._fixture()
+
+    @testing.fixture
+    def column_expression_fixture(self):
+        class MyString(TypeEngine):
+            def column_expression(self, column):
+                return func.lower(column)
+
+        return table(
+            "some_table", column("name", String), column("value", MyString)
+        )
+
+    def test_plain_select_compiler_expression(self, compiler_column_fixture):
+        expr = compiler_column_fixture
+        table1 = self.table1
+
+        self.assert_compile(
+            select(
+                table1.c.name,
+                expr(table1.c.value),
+            ),
+            "SELECT some_table.name, SOME_COL_THING(some_table.value) "
+            "AS value FROM some_table",
+        )
+
+    def test_plain_select_column_expression(self, column_expression_fixture):
+        table1 = column_expression_fixture
+
+        self.assert_compile(
+            select(table1),
+            "SELECT some_table.name, lower(some_table.value) AS value "
+            "FROM some_table",
+        )
+
+    def test_plain_returning_compiler_expression(
+        self, compiler_column_fixture
+    ):
+        expr = compiler_column_fixture
+        table1 = self.table1
+
+        self.assert_compile(
+            insert(table1).returning(
+                table1.c.name,
+                expr(table1.c.value),
+            ),
+            "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+            "RETURNING some_table.name, "
+            "SOME_COL_THING(some_table.value) AS value",
+        )
+
+    @testing.combinations("columns", "table", argnames="use_columns")
+    def test_plain_returning_column_expression(
+        self, column_expression_fixture, use_columns
+    ):
+        table1 = column_expression_fixture
+
+        if use_columns == "columns":
+            stmt = insert(table1).returning(table1)
+        else:
+            stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+            "RETURNING some_table.name, lower(some_table.value) AS value",
+        )
+
+    def test_select_dupes_column_expression(self, column_expression_fixture):
+        table1 = column_expression_fixture
+
+        self.assert_compile(
+            select(table1.c.name, table1.c.value, table1.c.value),
+            "SELECT some_table.name, lower(some_table.value) AS value, "
+            "lower(some_table.value) AS value__1 FROM some_table",
+        )
+
+    def test_returning_dupes_column_expression(
+        self, column_expression_fixture
+    ):
+        table1 = column_expression_fixture
+
+        stmt = insert(table1).returning(
+            table1.c.name, table1.c.value, table1.c.value
+        )
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+            "RETURNING some_table.name, lower(some_table.value) AS value, "
+            "lower(some_table.value) AS value__1",
+        )
+
     def test_column_auto_label_dupes_label_style_none(self):
         expr = self._fixture()
         table1 = self.table1
index 32d4c7740dc187bb685656953f142fc629f1128c..e0299e334a65e763b9f2d1ce516353d260a213a0 100644 (file)
@@ -415,6 +415,50 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults):
         result = connection.execute(ins)
         eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)])
 
+    @testing.fixture
+    def column_expression_fixture(self, metadata, connection):
+        class MyString(TypeDecorator):
+            cache_ok = True
+            impl = String(50)
+
+            def column_expression(self, column):
+                return func.lower(column)
+
+        t1 = Table(
+            "some_table",
+            metadata,
+            Column("name", String(50)),
+            Column("value", MyString(50)),
+        )
+        metadata.create_all(connection)
+        return t1
+
+    @testing.combinations("columns", "table", argnames="use_columns")
+    def test_plain_returning_column_expression(
+        self, column_expression_fixture, use_columns, connection
+    ):
+        """test #8770"""
+        table1 = column_expression_fixture
+
+        if use_columns == "columns":
+            stmt = (
+                insert(table1)
+                .values(name="n1", value="ValUE1")
+                .returning(table1)
+            )
+        else:
+            stmt = (
+                insert(table1)
+                .values(name="n1", value="ValUE1")
+                .returning(table1.c.name, table1.c.value)
+            )
+
+        result = connection.execute(stmt)
+        row = result.first()
+
+        eq_(row._mapping["name"], "n1")
+        eq_(row._mapping["value"], "value1")
+
     @testing.fails_on_everything_except(
         "postgresql", "mariadb>=10.5", "sqlite>=3.34"
     )