git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use driver col names
author Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 21 Dec 2023 17:14:00 +0000 (12:14 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 4 Jul 2024 14:50:19 +0000 (10:50 -0400)
Added new execution option
:paramref:`_engine.Connection.execution_options.driver_column_names`. This
option disables the "name normalize" step that takes place against the
DBAPI ``cursor.description`` for uppercase-default backends like Oracle,
and will cause the keys of a result set (e.g. named tuple names, dictionary
keys in :attr:`.Row._mapping`, etc.) to be exactly what was delivered in
cursor.description.  This is mostly useful for plain textual statements
using :func:`_sql.text` or :meth:`_engine.Connection.exec_driver_sql`.

Fixes: #10789
Change-Id: Ib647b25bb53492fa839af04dd032d9f061e630af

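For illustration, a minimal usage sketch of the new option (the Oracle DSN is a
placeholder; the key values shown are what an uppercase-default backend would
typically produce)::

    from sqlalchemy import create_engine, text

    engine = create_engine(
        "oracle+oracledb://scott:tiger@localhost/?service_name=FREEPDB1"
    )

    with engine.connect() as conn:
        # default: names from cursor.description are normalized, so Oracle's
        # all-uppercase SOME_COL comes back as "some_col"
        row = conn.execute(text("SELECT 1 AS some_col FROM DUAL")).one()
        print(list(row._mapping))  # ["some_col"]

        # with driver_column_names, the raw cursor.description name is kept
        row = conn.execute(
            text("SELECT 1 AS some_col FROM DUAL"),
            execution_options={"driver_column_names": True},
        ).one()
        print(list(row._mapping))  # ["SOME_COL"]
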
doc/build/changelog/unreleased_21/10789.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/testing/suite/test_results.py
test/sql/test_types.py
test/typing/test_overloads.py

diff --git a/doc/build/changelog/unreleased_21/10789.rst b/doc/build/changelog/unreleased_21/10789.rst
new file mode 100644 (file)
index 0000000..af3b301
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: usecase, engine
+    :tickets: 10789
+
+    Added new execution option
+    :paramref:`_engine.Connection.execution_options.driver_column_names`. This
+    option disables the "name normalize" step that takes place against the
+    DBAPI ``cursor.description`` for uppercase-default backends like Oracle,
+    and will cause the keys of a result set (e.g. named tuple names, dictionary
+    keys in :attr:`.Row._mapping`, etc.) to be exactly what was delivered in
+    cursor.description.   This is mostly useful for plain textual statements
+    using :func:`_sql.text` or :meth:`_engine.Connection.exec_driver_sql`.
index 4f180cbd9e74f5974becbe521d5dffaff876b348..dc347f0d798844a44b476b510872147ad8f0d8a1 100644 (file)
@@ -50,7 +50,7 @@ The :class:`_schema.Identity` object support many options to control the
 incrementing value, etc.
 In addition to the standard options, Oracle supports setting
 :paramref:`_schema.Identity.always` to ``None`` to use the default
-generated mode, rendering GENERATED AS IDENTITY in the DDL. 
+generated mode, rendering GENERATED AS IDENTITY in the DDL.
 Oracle also supports two custom options specified using dialect kwargs:
 
 * ``oracle_on_null``: when set to ``True`` renders ``ON NULL`` in conjunction
index 3451a824476cb74e669743433f574bf6cfc2b58c..e5557f7d2849705f09ebb93111ebc22b7040590f 100644 (file)
@@ -252,6 +252,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         insertmanyvalues_page_size: int = ...,
         schema_translate_map: Optional[SchemaTranslateMapType] = ...,
         preserve_rowcount: bool = False,
+        driver_column_names: bool = False,
         **opt: Any,
     ) -> Connection: ...
 
@@ -515,6 +516,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
             :ref:`orm_queryguide_execution_options` - documentation on all
             ORM-specific execution options
 
+        :param driver_column_names: When True, the returned
+         :class:`_engine.CursorResult` will use the column names as written in
+         ``cursor.description`` to set up the keys for the result set,
+         including the names of columns for the :class:`_engine.Row` object as
+         well as the dictionary keys when using :attr:`_engine.Row._mapping`.
+         On backends such as Oracle that apply "name normalization", where
+         all-uppercase names in ``cursor.description`` are converted back to
+         lowercase, this normalization is turned off and the raw UPPERCASE
+         names will be present.
+
+         .. versionadded:: 2.1
+
         """  # noqa
         if self._has_events or self.engine._has_events:
             self.dispatch.set_connection_execution_options(self, opt)
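As a sketch of the connection-level form documented above (the table name is
hypothetical; output shown for an uppercase-default backend such as Oracle)::

    conn = conn.execution_options(driver_column_names=True)
    result = conn.exec_driver_sql("SELECT id, name FROM some_table")
    print(list(result.keys()))  # ["ID", "NAME"] rather than ["id", "name"]
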
index 3a58e71a935e607e8861aa0d6f70853248d513b8..9ff5cdeb86e956ffc5bd3f840d05adf5ca0e0d01 100644 (file)
@@ -187,7 +187,7 @@ class CursorResultMetaData(ResultMetaData):
         translated_indexes: Optional[List[int]],
         safe_for_cache: bool,
         keymap_by_result_column_idx: Any,
-    ) -> CursorResultMetaData:
+    ) -> Self:
         new_obj = self.__class__.__new__(self.__class__)
         new_obj._unpickled = unpickled
         new_obj._processors = processors
@@ -200,7 +200,7 @@ class CursorResultMetaData(ResultMetaData):
         new_obj._key_to_index = self._make_key_to_index(keymap, MD_INDEX)
         return new_obj
 
-    def _remove_processors(self) -> CursorResultMetaData:
+    def _remove_processors(self) -> Self:
         assert not self._tuplefilter
         return self._make_new_metadata(
             unpickled=self._unpickled,
@@ -216,9 +216,7 @@ class CursorResultMetaData(ResultMetaData):
             keymap_by_result_column_idx=self._keymap_by_result_column_idx,
         )
 
-    def _splice_horizontally(
-        self, other: CursorResultMetaData
-    ) -> CursorResultMetaData:
+    def _splice_horizontally(self, other: CursorResultMetaData) -> Self:
         assert not self._tuplefilter
 
         keymap = dict(self._keymap)
@@ -252,7 +250,7 @@ class CursorResultMetaData(ResultMetaData):
             },
         )
 
-    def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
+    def _reduce(self, keys: Sequence[_KeyIndexType]) -> Self:
         recs = list(self._metadata_for_keys(keys))
 
         indexes = [rec[MD_INDEX] for rec in recs]
@@ -284,7 +282,7 @@ class CursorResultMetaData(ResultMetaData):
             keymap_by_result_column_idx=self._keymap_by_result_column_idx,
         )
 
-    def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData:
+    def _adapt_to_context(self, context: ExecutionContext) -> Self:
         """When using a cached Compiled construct that has a _result_map,
         for a new statement that used the cached Compiled, we need to ensure
         the keymap has the Column objects from our new statement as keys.
@@ -350,6 +348,8 @@ class CursorResultMetaData(ResultMetaData):
         self,
         parent: CursorResult[Unpack[TupleAny]],
         cursor_description: _DBAPICursorDescription,
+        *,
+        driver_column_names: bool = False,
     ):
         context = parent.context
         self._tuplefilter = None
@@ -383,6 +383,7 @@ class CursorResultMetaData(ResultMetaData):
             textual_ordered,
             ad_hoc_textual,
             loose_column_name_matching,
+            driver_column_names,
         )
 
         # processors in key order which are used when building up
@@ -474,15 +475,20 @@ class CursorResultMetaData(ResultMetaData):
                 for metadata_entry in raw
             }
 
-        # update keymap with "translated" names.  In SQLAlchemy this is a
-        # sqlite only thing, and in fact impacting only extremely old SQLite
-        # versions unlikely to be present in modern Python versions.
-        # however, the pyhive third party dialect is
-        # also using this hook, which means others still might use it as well.
-        # I dislike having this awkward hook here but as long as we need
-        # to use names in cursor.description in some cases we need to have
-        # some hook to accomplish this.
-        if not num_ctx_cols and context._translate_colname:
+        # update keymap with "translated" names.
+        # the "translated" name thing has a long history:
+        # 1. originally, it was used to fix an issue in very old SQLite
+        #    versions prior to 3.10.0.   This code is still there in the
+        #    sqlite dialect.
+        # 2. Next, the pyhive third party dialect started using this hook
+        #    for some driver related issue on their end.
+        # 3. Most recently, the "driver_column_names" execution option has
+        #    taken advantage of this hook to get raw DBAPI col names in the
+        #    result keys without disrupting the usual merge process.
+
+        if driver_column_names or (
+            not num_ctx_cols and context._translate_colname
+        ):
             self._keymap.update(
                 {
                     metadata_entry[MD_UNTRANSLATED]: self._keymap[
@@ -505,6 +511,7 @@ class CursorResultMetaData(ResultMetaData):
         textual_ordered,
         ad_hoc_textual,
         loose_column_name_matching,
+        driver_column_names,
     ):
         """Merge a cursor.description with compiled result column information.
 
@@ -566,6 +573,7 @@ class CursorResultMetaData(ResultMetaData):
             and cols_are_ordered
             and not textual_ordered
             and num_ctx_cols == len(cursor_description)
+            and not driver_column_names
         ):
             self._keys = [elem[0] for elem in result_columns]
             # pure positional 1-1 case; doesn't need to read
@@ -573,9 +581,11 @@ class CursorResultMetaData(ResultMetaData):
 
             # most common case for Core and ORM
 
-            # this metadata is safe to cache because we are guaranteed
+            # this metadata is safe to
+            # cache because we are guaranteed
             # to have the columns in the same order for new executions
             self._safe_for_cache = True
+
             return [
                 (
                     idx,
@@ -599,10 +609,13 @@ class CursorResultMetaData(ResultMetaData):
             if textual_ordered or (
                 ad_hoc_textual and len(cursor_description) == num_ctx_cols
             ):
-                self._safe_for_cache = True
+                self._safe_for_cache = not driver_column_names
                 # textual positional case
                 raw_iterator = self._merge_textual_cols_by_position(
-                    context, cursor_description, result_columns
+                    context,
+                    cursor_description,
+                    result_columns,
+                    driver_column_names,
                 )
             elif num_ctx_cols:
                 # compiled SQL with a mismatch of description cols
@@ -615,13 +628,14 @@ class CursorResultMetaData(ResultMetaData):
                     cursor_description,
                     result_columns,
                     loose_column_name_matching,
+                    driver_column_names,
                 )
             else:
                 # no compiled SQL, just a raw string, order of columns
                 # can change for "select *"
                 self._safe_for_cache = False
                 raw_iterator = self._merge_cols_by_none(
-                    context, cursor_description
+                    context, cursor_description, driver_column_names
                 )
 
             return [
@@ -647,39 +661,53 @@ class CursorResultMetaData(ResultMetaData):
                 ) in raw_iterator
             ]
 
-    def _colnames_from_description(self, context, cursor_description):
+    def _colnames_from_description(
+        self, context, cursor_description, driver_column_names
+    ):
         """Extract column names and data types from a cursor.description.
 
         Applies unicode decoding, column translation, "normalization",
         and case sensitivity rules to the names based on the dialect.
 
         """
-
         dialect = context.dialect
         translate_colname = context._translate_colname
         normalize_name = (
             dialect.normalize_name if dialect.requires_name_normalize else None
         )
-        untranslated = None
 
         self._keys = []
 
+        untranslated = None
+
         for idx, rec in enumerate(cursor_description):
-            colname = rec[0]
+            colname = unnormalized = rec[0]
             coltype = rec[1]
 
             if translate_colname:
+                # a None here for "untranslated" means "the dialect did not
+                # change the column name and the untranslated case can be
+                # ignored".  otherwise "untranslated" is expected to be the
+                # original, unchanged colname (e.g. is == to "unnormalized")
                 colname, untranslated = translate_colname(colname)
 
+                assert untranslated is None or untranslated == unnormalized
+
             if normalize_name:
                 colname = normalize_name(colname)
 
-            self._keys.append(colname)
+            if driver_column_names:
+                self._keys.append(unnormalized)
 
-            yield idx, colname, untranslated, coltype
+                yield idx, colname, unnormalized, coltype
+
+            else:
+                self._keys.append(colname)
+
+                yield idx, colname, untranslated, coltype
 
     def _merge_textual_cols_by_position(
-        self, context, cursor_description, result_columns
+        self, context, cursor_description, result_columns, driver_column_names
     ):
         num_ctx_cols = len(result_columns)
 
@@ -696,7 +724,9 @@ class CursorResultMetaData(ResultMetaData):
             colname,
             untranslated,
             coltype,
-        ) in self._colnames_from_description(context, cursor_description):
+        ) in self._colnames_from_description(
+            context, cursor_description, driver_column_names
+        ):
             if idx < num_ctx_cols:
                 ctx_rec = result_columns[idx]
                 obj = ctx_rec[RM_OBJECTS]
@@ -720,6 +750,7 @@ class CursorResultMetaData(ResultMetaData):
         cursor_description,
         result_columns,
         loose_column_name_matching,
+        driver_column_names,
     ):
         match_map = self._create_description_match_map(
             result_columns, loose_column_name_matching
@@ -731,7 +762,9 @@ class CursorResultMetaData(ResultMetaData):
             colname,
             untranslated,
             coltype,
-        ) in self._colnames_from_description(context, cursor_description):
+        ) in self._colnames_from_description(
+            context, cursor_description, driver_column_names
+        ):
             try:
                 ctx_rec = match_map[colname]
             except KeyError:
@@ -771,6 +804,7 @@ class CursorResultMetaData(ResultMetaData):
         ] = {}
         for ridx, elem in enumerate(result_columns):
             key = elem[RM_RENDERED_NAME]
+
             if key in d:
                 # conflicting keyname - just add the column-linked objects
                 # to the existing record.  if there is a duplicate column
@@ -794,13 +828,17 @@ class CursorResultMetaData(ResultMetaData):
                     )
         return d
 
-    def _merge_cols_by_none(self, context, cursor_description):
+    def _merge_cols_by_none(
+        self, context, cursor_description, driver_column_names
+    ):
         for (
             idx,
             colname,
             untranslated,
             coltype,
-        ) in self._colnames_from_description(context, cursor_description):
+        ) in self._colnames_from_description(
+            context, cursor_description, driver_column_names
+        ):
             yield (
                 idx,
                 None,
@@ -1489,10 +1527,20 @@ class CursorResult(Result[Unpack[_Ts]]):
             self._metadata = self._no_result_metadata
 
     def _init_metadata(self, context, cursor_description):
+        driver_column_names = context.execution_options.get(
+            "driver_column_names", False
+        )
         if context.compiled:
             compiled = context.compiled
 
-            if compiled._cached_metadata:
+            metadata: CursorResultMetaData
+
+            if driver_column_names:
+                metadata = CursorResultMetaData(
+                    self, cursor_description, driver_column_names=True
+                )
+                assert not metadata._safe_for_cache
+            elif compiled._cached_metadata:
                 metadata = compiled._cached_metadata
             else:
                 metadata = CursorResultMetaData(self, cursor_description)
@@ -1527,7 +1575,9 @@ class CursorResult(Result[Unpack[_Ts]]):
 
         else:
             self._metadata = metadata = CursorResultMetaData(
-                self, cursor_description
+                self,
+                cursor_description,
+                driver_column_names=driver_column_names,
             )
         if self._echo:
             context.connection._log_debug(
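For orientation, the "name normalize" step that ``driver_column_names``
bypasses can be sketched roughly as follows; this is a simplification of what
uppercase-default dialects such as Oracle do, not the actual implementation::

    def normalize_name(name: str) -> str:
        # uppercase-default backends report unquoted (case-insensitive)
        # identifiers as all-uppercase; SQLAlchemy presents those in
        # lowercase.  mixed-case names were quoted and are left as-is.
        if name.upper() == name:
            return name.lower()
        return name

    normalize_name("SOME_COL")   # -> "some_col"
    normalize_name("MixedCase")  # -> "MixedCase"
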
index d4c5aef7976ab91fc7195082978da7c7b17bb510..52821b0ca109fcd99b446b14dad849a09275db86 100644 (file)
@@ -271,6 +271,7 @@ class _CoreKnownExecutionOptions(TypedDict, total=False):
     insertmanyvalues_page_size: int
     schema_translate_map: Optional[SchemaTranslateMapType]
     preserve_rowcount: bool
+    driver_column_names: bool
 
 
 _ExecuteOptions = immutabledict[str, Any]
index 16d14ef5dbe6effa06b4ee92f94ce614b6efd456..0b572d426a274e6d13d329fff30c33c916e08635 100644 (file)
@@ -421,6 +421,7 @@ class AsyncConnection(
         insertmanyvalues_page_size: int = ...,
         schema_translate_map: Optional[SchemaTranslateMapType] = ...,
         preserve_rowcount: bool = False,
+        driver_column_names: bool = False,
         **opt: Any,
     ) -> AsyncConnection: ...
 
index bfc0fb36527c1fc226991391344e602d5f0121fa..b535b9db2d2f611c45f6b7e56d3aaab0c892adf4 100644 (file)
@@ -1728,6 +1728,7 @@ class Query(
         stream_results: bool = False,
         max_row_buffer: int = ...,
         yield_per: int = ...,
+        driver_column_names: bool = ...,
         insertmanyvalues_page_size: int = ...,
         schema_translate_map: Optional[SchemaTranslateMapType] = ...,
         populate_existing: bool = False,
index 96a9337f48c447e664326eb44300f4cfec85f901..dcb00e16a52b63e3c555627fd202b896cbce6e8e 100644 (file)
@@ -1157,6 +1157,7 @@ class Executable(roles.StatementRole):
         stream_results: bool = False,
         max_row_buffer: int = ...,
         yield_per: int = ...,
+        driver_column_names: bool = ...,
         insertmanyvalues_page_size: int = ...,
         schema_translate_map: Optional[SchemaTranslateMapType] = ...,
         populate_existing: bool = False,
index b3f432fb76c5772a7fcbe79c2ae940fba64f2b52..05e35d0ebf31a63d3f460b32a3c9b916f2933d16 100644 (file)
@@ -17,6 +17,7 @@ from ..schema import Table
 from ... import DateTime
 from ... import func
 from ... import Integer
+from ... import quoted_name
 from ... import select
 from ... import sql
 from ... import String
@@ -118,6 +119,165 @@ class RowFetchTest(fixtures.TablesTest):
         eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))
 
 
+class NameDenormalizeTest(fixtures.TablesTest):
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        cls.tables.denormalize_table = Table(
+            "denormalize_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("all_lowercase", Integer),
+            Column("ALL_UPPERCASE", Integer),
+            Column("MixedCase", Integer),
+            Column(quoted_name("all_lowercase_quoted", quote=True), Integer),
+            Column(quoted_name("ALL_UPPERCASE_QUOTED", quote=True), Integer),
+        )
+
+    @classmethod
+    def insert_data(cls, connection):
+        connection.execute(
+            cls.tables.denormalize_table.insert(),
+            {
+                "id": 1,
+                "all_lowercase": 5,
+                "ALL_UPPERCASE": 6,
+                "MixedCase": 7,
+                "all_lowercase_quoted": 8,
+                "ALL_UPPERCASE_QUOTED": 9,
+            },
+        )
+
+    def _assert_row_mapping(self, row, mapping, include_cols=None):
+        eq_(row._mapping, mapping)
+
+        for k in mapping:
+            eq_(row._mapping[k], mapping[k])
+            eq_(getattr(row, k), mapping[k])
+
+        for idx, k in enumerate(mapping):
+            eq_(row[idx], mapping[k])
+
+        if include_cols:
+            for col, (idx, k) in zip(include_cols, enumerate(mapping)):
+                eq_(row._mapping[col], mapping[k])
+
+    @testing.variation(
+        "stmt_type", ["driver_sql", "text_star", "core_select", "text_cols"]
+    )
+    @testing.variation("use_driver_cols", [True, False])
+    def test_cols_driver_cols(self, connection, stmt_type, use_driver_cols):
+        if stmt_type.driver_sql or stmt_type.text_star or stmt_type.text_cols:
+            stmt = select("*").select_from(self.tables.denormalize_table)
+            text_stmt = str(stmt.compile(connection))
+
+            if stmt_type.text_star or stmt_type.text_cols:
+                stmt = text(text_stmt)
+
+                if stmt_type.text_cols:
+                    stmt = stmt.columns(*self.tables.denormalize_table.c)
+        elif stmt_type.core_select:
+            stmt = select(self.tables.denormalize_table)
+        else:
+            stmt_type.fail()
+
+        if use_driver_cols:
+            execution_options = {"driver_column_names": True}
+        else:
+            execution_options = {}
+
+        if stmt_type.driver_sql:
+            row = connection.exec_driver_sql(
+                text_stmt, execution_options=execution_options
+            ).one()
+        else:
+            row = connection.execute(
+                stmt,
+                execution_options=execution_options,
+            ).one()
+
+        if (
+            stmt_type.core_select and not use_driver_cols
+        ) or not testing.requires.denormalized_names.enabled:
+            self._assert_row_mapping(
+                row,
+                {
+                    "id": 1,
+                    "all_lowercase": 5,
+                    "ALL_UPPERCASE": 6,
+                    "MixedCase": 7,
+                    "all_lowercase_quoted": 8,
+                    "ALL_UPPERCASE_QUOTED": 9,
+                },
+            )
+
+        if testing.requires.denormalized_names.enabled:
+            # with driver column names, raw cursor.description
+            # is used.  this is clearly not useful for non-quoted names.
+            if use_driver_cols:
+                self._assert_row_mapping(
+                    row,
+                    {
+                        "ID": 1,
+                        "ALL_LOWERCASE": 5,
+                        "ALL_UPPERCASE": 6,
+                        "MixedCase": 7,
+                        "all_lowercase_quoted": 8,
+                        "ALL_UPPERCASE_QUOTED": 9,
+                    },
+                )
+            else:
+                if stmt_type.core_select:
+                    self._assert_row_mapping(
+                        row,
+                        {
+                            "id": 1,
+                            "all_lowercase": 5,
+                            "ALL_UPPERCASE": 6,
+                            "MixedCase": 7,
+                            "all_lowercase_quoted": 8,
+                            "ALL_UPPERCASE_QUOTED": 9,
+                        },
+                        include_cols=self.tables.denormalize_table.c,
+                    )
+                else:
+                    self._assert_row_mapping(
+                        row,
+                        {
+                            "id": 1,
+                            "all_lowercase": 5,
+                            "all_uppercase": 6,
+                            "MixedCase": 7,
+                            "all_lowercase_quoted": 8,
+                            "all_uppercase_quoted": 9,
+                        },
+                        include_cols=(
+                            self.tables.denormalize_table.c
+                            if stmt_type.text_cols
+                            else None
+                        ),
+                    )
+
+        else:
+            self._assert_row_mapping(
+                row,
+                {
+                    "id": 1,
+                    "all_lowercase": 5,
+                    "ALL_UPPERCASE": 6,
+                    "MixedCase": 7,
+                    "all_lowercase_quoted": 8,
+                    "ALL_UPPERCASE_QUOTED": 9,
+                },
+                include_cols=(
+                    self.tables.denormalize_table.c
+                    if stmt_type.core_select or stmt_type.text_cols
+                    else None
+                ),
+            )
+
+
 class PercentSchemaNamesTest(fixtures.TablesTest):
     """tests using percent signs, spaces in table and column names.
 
index 36c6a74c27ef77e2ed3c6126649be130324db686..44cd1162bb87308a7a4337edef3a962021953380 100644 (file)
@@ -787,13 +787,20 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
             ),
         )
 
-    def test_processing(self, connection):
+    @testing.variation("use_driver_cols", [True, False])
+    def test_processing(self, connection, use_driver_cols):
         users = self.tables.users
         self._data_fixture(connection)
 
-        result = connection.execute(
-            users.select().order_by(users.c.user_id)
-        ).fetchall()
+        if use_driver_cols:
+            result = connection.execute(
+                users.select().order_by(users.c.user_id),
+                execution_options={"driver_column_names": True},
+            ).fetchall()
+        else:
+            result = connection.execute(
+                users.select().order_by(users.c.user_id)
+            ).fetchall()
         eq_(
             result,
             [
index 66209f50365948caf21f25722fbbcf5a390783c3..1c50845493cd5597f57232059356f6ff9a632ecb 100644 (file)
@@ -25,6 +25,7 @@ core_execution_options = {
     "max_row_buffer": "int",
     "yield_per": "int",
     "preserve_rowcount": "bool",
+    "driver_column_names": "bool",
 }
 
 orm_dql_execution_options = {