]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
break out text() from TextualSelect for col matching
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Sep 2022 13:40:40 +0000 (09:40 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Sep 2022 01:50:03 +0000 (21:50 -0400)
Fixed issue where mixing "*" with additional explicitly-named column
expressions within the columns clause of a :func:`_sql.select` construct
would cause result-column targeting to sometimes consider the label name or
other non-repeated names to be an ambiguous target.

Fixes: #8536
Change-Id: I3c845eaf571033e54c9208762344f67f4351ac3a
(cherry picked from commit 78327d98be9236c61f950526470f29b184dabba6)

doc/build/changelog/unreleased_14/8536.rst [new file with mode: 0644]
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/requirements.py
test/sql/test_resultset.py

diff --git a/doc/build/changelog/unreleased_14/8536.rst b/doc/build/changelog/unreleased_14/8536.rst
new file mode 100644 (file)
index 0000000..d7b5283
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, engine
+    :tickets: 8536
+
+    Fixed issue where mixing "*" with additional explicitly-named column
+    expressions within the columns clause of a :func:`_sql.select` construct
+    would cause result-column targeting to sometimes consider the label name or
+    other non-repeated names to be an ambiguous target.
index 774916d95df6c6cbaf6c47c6179fa2a208417b5d..168e08d111401ea4481a77bf64fb4835194dcc77 100644 (file)
@@ -165,6 +165,7 @@ class CursorResultMetaData(ResultMetaData):
                 result_columns,
                 cols_are_ordered,
                 textual_ordered,
+                ad_hoc_textual,
                 loose_column_name_matching,
             ) = context.result_column_struct
             num_ctx_cols = len(result_columns)
@@ -173,6 +174,8 @@ class CursorResultMetaData(ResultMetaData):
                 cols_are_ordered
             ) = (
                 num_ctx_cols
+            ) = (
+                ad_hoc_textual
             ) = loose_column_name_matching = textual_ordered = False
 
         # merge cursor.description with the column info
@@ -184,6 +187,7 @@ class CursorResultMetaData(ResultMetaData):
             num_ctx_cols,
             cols_are_ordered,
             textual_ordered,
+            ad_hoc_textual,
             loose_column_name_matching,
         )
 
@@ -214,11 +218,18 @@ class CursorResultMetaData(ResultMetaData):
         # column keys and other names
         if num_ctx_cols:
 
-            # if by-primary-string dictionary smaller (or bigger?!) than
-            # number of columns, assume we have dupes, rewrite
-            # dupe records with "None" for index which results in
-            # ambiguous column exception when accessed.
             if len(by_key) != num_ctx_cols:
+                # if by-primary-string dictionary smaller than
+                # number of columns, assume we have dupes; (this check
+                # is also in place if string dictionary is bigger, as
+                # can occur when '*' was used as one of the compiled columns,
+                # which may or may not be suggestive of dupes), rewrite
+                # dupe records with "None" for index which results in
+                # ambiguous column exception when accessed.
+                #
+                # this is considered to be the less common case as it is not
+                # common to have dupe column keys in a SELECT statement.
+                #
                 # new in 1.4: get the complete set of all possible keys,
                 # strings, objects, whatever, that are dupes across two
                 # different records, first.
@@ -291,6 +302,7 @@ class CursorResultMetaData(ResultMetaData):
         num_ctx_cols,
         cols_are_ordered,
         textual_ordered,
+        ad_hoc_textual,
         loose_column_name_matching,
     ):
         """Merge a cursor.description with compiled result column information.
@@ -386,7 +398,9 @@ class CursorResultMetaData(ResultMetaData):
             # name-based or text-positional cases, where we need
             # to read cursor.description names
 
-            if textual_ordered:
+            if textual_ordered or (
+                ad_hoc_textual and len(cursor_description) == num_ctx_cols
+            ):
                 self._safe_for_cache = True
                 # textual positional case
                 raw_iterator = self._merge_textual_cols_by_position(
index 6b58c44696b7c5b4db9f51bc3078f7f8e5ffb1b1..e050bea7a7fb1dc7c4775a1efcfcb3b1c3b553f5 100644 (file)
@@ -975,6 +975,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             compiled._result_columns,
             compiled._ordered_columns,
             compiled._textual_ordered_columns,
+            compiled._ad_hoc_textual,
             compiled._loose_column_name_matching,
         )
         self.isinsert = compiled.isinsert
index c9b6ba670c2c2368af372ecbf02b516f44e52634..0e441fbec8e6c39a78551fc7885c335342c95c57 100644 (file)
@@ -611,6 +611,20 @@ class SQLCompiler(Compiled):
     _textual_ordered_columns = False
     """tell the result object that the column names as rendered are important,
     but they are also "ordered" vs. what is in the compiled object here.
+
+    As of 1.4.42 this condition is only present when the statement is a
+    TextualSelect, e.g. text("....").columns(...), where it is required
+    that the columns are considered positionally and not by name.
+
+    """
+
+    _ad_hoc_textual = False
+    """tell the result that we encountered text() or '*' constructs in the
+    middle of the result columns, but we also have compiled columns, so
+    if the number of columns in cursor.description does not match how many
+    expressions we have, that means we can't rely on positional at all and
+    should match on name.
+
     """
 
     _ordered_columns = True
@@ -3024,7 +3038,7 @@ class SQLCompiler(Compiled):
     def _add_to_result_map(self, keyname, name, objects, type_):
         if keyname is None or keyname == "*":
             self._ordered_columns = False
-            self._textual_ordered_columns = True
+            self._ad_hoc_textual = True
         if type_._is_tuple_type:
             raise exc.CompileError(
                 "Most backends don't support SELECTing "
index 68e5f8bfe2661db8c63ac296703f9e0cc766a113..ca074c79b26f7a614d4af694ff517f7e21ffe8d1 100644 (file)
@@ -363,6 +363,17 @@ class DefaultRequirements(SuiteRequirements):
 
         return skip_if(["+pyodbc"], "no driver support")
 
+    @property
+    def select_star_mixed(self):
+        r"""target supports expressions like "SELECT x, y, \*, z FROM table"
+
+        apparently MySQL / MariaDB, Oracle doesn't handle this.
+
+        We only need a few backends so just cover SQLite / PG
+
+        """
+        return only_on(["sqlite", "postgresql"])
+
     @property
     def independent_connections(self):
         """
index 13190f915f9e28c87c58247c35399d5a84eccc5e..5d29b0b2b1fe3a6115f80648cc59428273008ccc 100644 (file)
@@ -1020,6 +1020,50 @@ class CursorResultTest(fixtures.TablesTest):
             set([True]),
         )
 
+    @testing.combinations(
+        (("name_label", "*"), False),
+        (("*", "name_label"), False),
+        (("user_id", "name_label", "user_name"), False),
+        (("user_id", "name_label", "*", "user_name"), True),
+        argnames="cols,other_cols_are_ambiguous",
+    )
+    @testing.requires.select_star_mixed
+    def test_label_against_star(
+        self, connection, cols, other_cols_are_ambiguous
+    ):
+        """test #8536"""
+        users = self.tables.users
+
+        connection.execute(users.insert(), dict(user_id=1, user_name="john"))
+
+        stmt = select(
+            *[
+                text("*")
+                if colname == "*"
+                else users.c.user_name.label("name_label")
+                if colname == "name_label"
+                else users.c[colname]
+                for colname in cols
+            ]
+        )
+
+        row = connection.execute(stmt).first()
+
+        eq_(row._mapping["name_label"], "john")
+
+        if other_cols_are_ambiguous:
+            with expect_raises_message(
+                exc.InvalidRequestError, "Ambiguous column name"
+            ):
+                row._mapping["user_id"]
+            with expect_raises_message(
+                exc.InvalidRequestError, "Ambiguous column name"
+            ):
+                row._mapping["user_name"]
+        else:
+            eq_(row._mapping["user_id"], 1)
+            eq_(row._mapping["user_name"], "john")
+
     def test_loose_matching_one(self, connection):
         users = self.tables.users
         addresses = self.tables.addresses