From: Mike Bayer Date: Thu, 21 Dec 2023 17:14:00 +0000 (-0500) Subject: use driver col names X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b3105b7e3a9e6a5ff4771c1e9348eb551f4dd454;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git use driver col names Added new execution option :paramref:`_engine.Connection.execution_options.driver_column_names`. This option disables the "name normalize" step that takes place against the DBAPI ``cursor.description`` for uppercase-default backends like Oracle, and will cause the keys of a result set (e.g. named tuple names, dictionary keys in :attr:`.Row._mapping`, etc.) to be exactly what was delivered in cursor.description. This is mostly useful for plain textual statements using :func:`_sql.text` or :meth:`_engine.Connection.exec_driver_sql`. Fixes: #10789 Change-Id: Ib647b25bb53492fa839af04dd032d9f061e630af --- diff --git a/doc/build/changelog/unreleased_21/10789.rst b/doc/build/changelog/unreleased_21/10789.rst new file mode 100644 index 0000000000..af3b301b54 --- /dev/null +++ b/doc/build/changelog/unreleased_21/10789.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: usecase, engine + :tickets: 10789 + + Added new execution option + :paramref:`_engine.Connection.execution_options.driver_column_names`. This + option disables the "name normalize" step that takes place against the + DBAPI ``cursor.description`` for uppercase-default backends like Oracle, + and will cause the keys of a result set (e.g. named tuple names, dictionary + keys in :attr:`.Row._mapping`, etc.) to be exactly what was delivered in + cursor.description. This is mostly useful for plain textual statements + using :func:`_sql.text` or :meth:`_engine.Connection.exec_driver_sql`. diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 4f180cbd9e..dc347f0d79 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -50,7 +50,7 @@ The :class:`_schema.Identity` object support many options to control the incrementing value, etc. In addition to the standard options, Oracle supports setting :paramref:`_schema.Identity.always` to ``None`` to use the default -generated mode, rendering GENERATED AS IDENTITY in the DDL. +generated mode, rendering GENERATED AS IDENTITY in the DDL. Oracle also supports two custom options specified using dialect kwargs: * ``oracle_on_null``: when set to ``True`` renders ``ON NULL`` in conjunction diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 3451a82447..e5557f7d28 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -252,6 +252,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., preserve_rowcount: bool = False, + driver_column_names: bool = False, **opt: Any, ) -> Connection: ... @@ -515,6 +516,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): :ref:`orm_queryguide_execution_options` - documentation on all ORM-specific execution options + :param driver_column_names: When True, the returned + :class:`_engine.CursorResult` will use the column names as written in + ``cursor.description`` to set up the keys for the result set, + including the names of columns for the :class:`_engine.Row` object as + well as the dictionary keys when using :attr:`_engine.Row._mapping`. + On backends that use "name normalization" such as Oracle to correct + for lower case names being converted to all uppercase, this behavior + is turned off and the raw UPPERCASE names in cursor.description will + be present. + + .. versionadded:: 2.1 + """ # noqa if self._has_events or self.engine._has_events: self.dispatch.set_connection_execution_options(self, opt) diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 3a58e71a93..9ff5cdeb86 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -187,7 +187,7 @@ class CursorResultMetaData(ResultMetaData): translated_indexes: Optional[List[int]], safe_for_cache: bool, keymap_by_result_column_idx: Any, - ) -> CursorResultMetaData: + ) -> Self: new_obj = self.__class__.__new__(self.__class__) new_obj._unpickled = unpickled new_obj._processors = processors @@ -200,7 +200,7 @@ class CursorResultMetaData(ResultMetaData): new_obj._key_to_index = self._make_key_to_index(keymap, MD_INDEX) return new_obj - def _remove_processors(self) -> CursorResultMetaData: + def _remove_processors(self) -> Self: assert not self._tuplefilter return self._make_new_metadata( unpickled=self._unpickled, @@ -216,9 +216,7 @@ class CursorResultMetaData(ResultMetaData): keymap_by_result_column_idx=self._keymap_by_result_column_idx, ) - def _splice_horizontally( - self, other: CursorResultMetaData - ) -> CursorResultMetaData: + def _splice_horizontally(self, other: CursorResultMetaData) -> Self: assert not self._tuplefilter keymap = dict(self._keymap) @@ -252,7 +250,7 @@ class CursorResultMetaData(ResultMetaData): }, ) - def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: + def _reduce(self, keys: Sequence[_KeyIndexType]) -> Self: recs = list(self._metadata_for_keys(keys)) indexes = [rec[MD_INDEX] for rec in recs] @@ -284,7 +282,7 @@ class CursorResultMetaData(ResultMetaData): keymap_by_result_column_idx=self._keymap_by_result_column_idx, ) - def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData: + def _adapt_to_context(self, context: ExecutionContext) -> Self: """When using a cached Compiled construct that has a _result_map, for a new statement that used the cached Compiled, we need to ensure the keymap has the Column objects from our new statement as keys. @@ -350,6 +348,8 @@ class CursorResultMetaData(ResultMetaData): self, parent: CursorResult[Unpack[TupleAny]], cursor_description: _DBAPICursorDescription, + *, + driver_column_names: bool = False, ): context = parent.context self._tuplefilter = None @@ -383,6 +383,7 @@ class CursorResultMetaData(ResultMetaData): textual_ordered, ad_hoc_textual, loose_column_name_matching, + driver_column_names, ) # processors in key order which are used when building up @@ -474,15 +475,20 @@ class CursorResultMetaData(ResultMetaData): for metadata_entry in raw } - # update keymap with "translated" names. In SQLAlchemy this is a - # sqlite only thing, and in fact impacting only extremely old SQLite - # versions unlikely to be present in modern Python versions. - # however, the pyhive third party dialect is - # also using this hook, which means others still might use it as well. - # I dislike having this awkward hook here but as long as we need - # to use names in cursor.description in some cases we need to have - # some hook to accomplish this. - if not num_ctx_cols and context._translate_colname: + # update keymap with "translated" names. + # the "translated" name thing has a long history: + # 1. originally, it was used to fix an issue in very old SQLite + # versions prior to 3.10.0. This code is still there in the + # sqlite dialect. + # 2. Next, the pyhive third party dialect started using this hook + # for some driver related issue on their end. + # 3. Most recently, the "driver_column_names" execution option has + # taken advantage of this hook to get raw DBAPI col names in the + # result keys without disrupting the usual merge process. + + if driver_column_names or ( + not num_ctx_cols and context._translate_colname + ): self._keymap.update( { metadata_entry[MD_UNTRANSLATED]: self._keymap[ @@ -505,6 +511,7 @@ class CursorResultMetaData(ResultMetaData): textual_ordered, ad_hoc_textual, loose_column_name_matching, + driver_column_names, ): """Merge a cursor.description with compiled result column information. @@ -566,6 +573,7 @@ class CursorResultMetaData(ResultMetaData): and cols_are_ordered and not textual_ordered and num_ctx_cols == len(cursor_description) + and not driver_column_names ): self._keys = [elem[0] for elem in result_columns] # pure positional 1-1 case; doesn't need to read @@ -573,9 +581,11 @@ class CursorResultMetaData(ResultMetaData): # most common case for Core and ORM - # this metadata is safe to cache because we are guaranteed + # this metadata is safe to + # cache because we are guaranteed # to have the columns in the same order for new executions self._safe_for_cache = True + return [ ( idx, @@ -599,10 +609,13 @@ class CursorResultMetaData(ResultMetaData): if textual_ordered or ( ad_hoc_textual and len(cursor_description) == num_ctx_cols ): - self._safe_for_cache = True + self._safe_for_cache = not driver_column_names # textual positional case raw_iterator = self._merge_textual_cols_by_position( - context, cursor_description, result_columns + context, + cursor_description, + result_columns, + driver_column_names, ) elif num_ctx_cols: # compiled SQL with a mismatch of description cols @@ -615,13 +628,14 @@ class CursorResultMetaData(ResultMetaData): cursor_description, result_columns, loose_column_name_matching, + driver_column_names, ) else: # no compiled SQL, just a raw string, order of columns # can change for "select *" self._safe_for_cache = False raw_iterator = self._merge_cols_by_none( - context, cursor_description + context, cursor_description, driver_column_names ) return [ @@ -647,39 +661,53 @@ class CursorResultMetaData(ResultMetaData): ) in raw_iterator ] - def _colnames_from_description(self, context, cursor_description): + def _colnames_from_description( + self, context, cursor_description, driver_column_names + ): """Extract column names and data types from a cursor.description. Applies unicode decoding, column translation, "normalization", and case sensitivity rules to the names based on the dialect. """ - dialect = context.dialect translate_colname = context._translate_colname normalize_name = ( dialect.normalize_name if dialect.requires_name_normalize else None ) - untranslated = None self._keys = [] + untranslated = None + for idx, rec in enumerate(cursor_description): - colname = rec[0] + colname = unnormalized = rec[0] coltype = rec[1] if translate_colname: + # a None here for "untranslated" means "the dialect did not + # change the column name and the untranslated case can be + # ignored". otherwise "untranslated" is expected to be the + # original, unchanged colname (e.g. is == to "unnormalized") colname, untranslated = translate_colname(colname) + assert untranslated is None or untranslated == unnormalized + if normalize_name: colname = normalize_name(colname) - self._keys.append(colname) + if driver_column_names: + self._keys.append(unnormalized) - yield idx, colname, untranslated, coltype + yield idx, colname, unnormalized, coltype + + else: + self._keys.append(colname) + + yield idx, colname, untranslated, coltype def _merge_textual_cols_by_position( - self, context, cursor_description, result_columns + self, context, cursor_description, result_columns, driver_column_names ): num_ctx_cols = len(result_columns) @@ -696,7 +724,9 @@ class CursorResultMetaData(ResultMetaData): colname, untranslated, coltype, - ) in self._colnames_from_description(context, cursor_description): + ) in self._colnames_from_description( + context, cursor_description, driver_column_names + ): if idx < num_ctx_cols: ctx_rec = result_columns[idx] obj = ctx_rec[RM_OBJECTS] @@ -720,6 +750,7 @@ class CursorResultMetaData(ResultMetaData): cursor_description, result_columns, loose_column_name_matching, + driver_column_names, ): match_map = self._create_description_match_map( result_columns, loose_column_name_matching @@ -731,7 +762,9 @@ class CursorResultMetaData(ResultMetaData): colname, untranslated, coltype, - ) in self._colnames_from_description(context, cursor_description): + ) in self._colnames_from_description( + context, cursor_description, driver_column_names + ): try: ctx_rec = match_map[colname] except KeyError: @@ -771,6 +804,7 @@ class CursorResultMetaData(ResultMetaData): ] = {} for ridx, elem in enumerate(result_columns): key = elem[RM_RENDERED_NAME] + if key in d: # conflicting keyname - just add the column-linked objects # to the existing record. if there is a duplicate column @@ -794,13 +828,17 @@ class CursorResultMetaData(ResultMetaData): ) return d - def _merge_cols_by_none(self, context, cursor_description): + def _merge_cols_by_none( + self, context, cursor_description, driver_column_names + ): for ( idx, colname, untranslated, coltype, - ) in self._colnames_from_description(context, cursor_description): + ) in self._colnames_from_description( + context, cursor_description, driver_column_names + ): yield ( idx, None, @@ -1489,10 +1527,20 @@ class CursorResult(Result[Unpack[_Ts]]): self._metadata = self._no_result_metadata def _init_metadata(self, context, cursor_description): + driver_column_names = context.execution_options.get( + "driver_column_names", False + ) if context.compiled: compiled = context.compiled - if compiled._cached_metadata: + metadata: CursorResultMetaData + + if driver_column_names: + metadata = CursorResultMetaData( + self, cursor_description, driver_column_names=True + ) + assert not metadata._safe_for_cache + elif compiled._cached_metadata: metadata = compiled._cached_metadata else: metadata = CursorResultMetaData(self, cursor_description) @@ -1527,7 +1575,9 @@ class CursorResult(Result[Unpack[_Ts]]): else: self._metadata = metadata = CursorResultMetaData( - self, cursor_description + self, + cursor_description, + driver_column_names=driver_column_names, ) if self._echo: context.connection._log_debug( diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index d4c5aef797..52821b0ca1 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -271,6 +271,7 @@ class _CoreKnownExecutionOptions(TypedDict, total=False): insertmanyvalues_page_size: int schema_translate_map: Optional[SchemaTranslateMapType] preserve_rowcount: bool + driver_column_names: bool _ExecuteOptions = immutabledict[str, Any] diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 16d14ef5db..0b572d426a 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -421,6 +421,7 @@ class AsyncConnection( insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., preserve_rowcount: bool = False, + driver_column_names: bool = False, **opt: Any, ) -> AsyncConnection: ... diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index bfc0fb3652..b535b9db2d 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1728,6 +1728,7 @@ class Query( stream_results: bool = False, max_row_buffer: int = ..., yield_per: int = ..., + driver_column_names: bool = ..., insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., populate_existing: bool = False, diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 96a9337f48..dcb00e16a5 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -1157,6 +1157,7 @@ class Executable(roles.StatementRole): stream_results: bool = False, max_row_buffer: int = ..., yield_per: int = ..., + driver_column_names: bool = ..., insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., populate_existing: bool = False, diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index b3f432fb76..05e35d0ebf 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -17,6 +17,7 @@ from ..schema import Table from ... import DateTime from ... import func from ... import Integer +from ... import quoted_name from ... import select from ... import sql from ... import String @@ -118,6 +119,165 @@ class RowFetchTest(fixtures.TablesTest): eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0)) +class NameDenormalizeTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + cls.tables.denormalize_table = Table( + "denormalize_table", + metadata, + Column("id", Integer, primary_key=True), + Column("all_lowercase", Integer), + Column("ALL_UPPERCASE", Integer), + Column("MixedCase", Integer), + Column(quoted_name("all_lowercase_quoted", quote=True), Integer), + Column(quoted_name("ALL_UPPERCASE_QUOTED", quote=True), Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.denormalize_table.insert(), + { + "id": 1, + "all_lowercase": 5, + "ALL_UPPERCASE": 6, + "MixedCase": 7, + "all_lowercase_quoted": 8, + "ALL_UPPERCASE_QUOTED": 9, + }, + ) + + def _assert_row_mapping(self, row, mapping, include_cols=None): + eq_(row._mapping, mapping) + + for k in mapping: + eq_(row._mapping[k], mapping[k]) + eq_(getattr(row, k), mapping[k]) + + for idx, k in enumerate(mapping): + eq_(row[idx], mapping[k]) + + if include_cols: + for col, (idx, k) in zip(include_cols, enumerate(mapping)): + eq_(row._mapping[col], mapping[k]) + + @testing.variation( + "stmt_type", ["driver_sql", "text_star", "core_select", "text_cols"] + ) + @testing.variation("use_driver_cols", [True, False]) + def test_cols_driver_cols(self, connection, stmt_type, use_driver_cols): + if stmt_type.driver_sql or stmt_type.text_star or stmt_type.text_cols: + stmt = select("*").select_from(self.tables.denormalize_table) + text_stmt = str(stmt.compile(connection)) + + if stmt_type.text_star or stmt_type.text_cols: + stmt = text(text_stmt) + + if stmt_type.text_cols: + stmt = stmt.columns(*self.tables.denormalize_table.c) + elif stmt_type.core_select: + stmt = select(self.tables.denormalize_table) + else: + stmt_type.fail() + + if use_driver_cols: + execution_options = {"driver_column_names": True} + else: + execution_options = {} + + if stmt_type.driver_sql: + row = connection.exec_driver_sql( + text_stmt, execution_options=execution_options + ).one() + else: + row = connection.execute( + stmt, + execution_options=execution_options, + ).one() + + if ( + stmt_type.core_select and not use_driver_cols + ) or not testing.requires.denormalized_names.enabled: + self._assert_row_mapping( + row, + { + "id": 1, + "all_lowercase": 5, + "ALL_UPPERCASE": 6, + "MixedCase": 7, + "all_lowercase_quoted": 8, + "ALL_UPPERCASE_QUOTED": 9, + }, + ) + + if testing.requires.denormalized_names.enabled: + # with driver column names, raw cursor.description + # is used. this is clearly not useful for non-quoted names. + if use_driver_cols: + self._assert_row_mapping( + row, + { + "ID": 1, + "ALL_LOWERCASE": 5, + "ALL_UPPERCASE": 6, + "MixedCase": 7, + "all_lowercase_quoted": 8, + "ALL_UPPERCASE_QUOTED": 9, + }, + ) + else: + if stmt_type.core_select: + self._assert_row_mapping( + row, + { + "id": 1, + "all_lowercase": 5, + "ALL_UPPERCASE": 6, + "MixedCase": 7, + "all_lowercase_quoted": 8, + "ALL_UPPERCASE_QUOTED": 9, + }, + include_cols=self.tables.denormalize_table.c, + ) + else: + self._assert_row_mapping( + row, + { + "id": 1, + "all_lowercase": 5, + "all_uppercase": 6, + "MixedCase": 7, + "all_lowercase_quoted": 8, + "all_uppercase_quoted": 9, + }, + include_cols=( + self.tables.denormalize_table.c + if stmt_type.text_cols + else None + ), + ) + + else: + self._assert_row_mapping( + row, + { + "id": 1, + "all_lowercase": 5, + "ALL_UPPERCASE": 6, + "MixedCase": 7, + "all_lowercase_quoted": 8, + "ALL_UPPERCASE_QUOTED": 9, + }, + include_cols=( + self.tables.denormalize_table.c + if stmt_type.core_select or stmt_type.text_cols + else None + ), + ) + + class PercentSchemaNamesTest(fixtures.TablesTest): """tests using percent signs, spaces in table and column names. diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 36c6a74c27..44cd1162bb 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -787,13 +787,20 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): ), ) - def test_processing(self, connection): + @testing.variation("use_driver_cols", [True, False]) + def test_processing(self, connection, use_driver_cols): users = self.tables.users self._data_fixture(connection) - result = connection.execute( - users.select().order_by(users.c.user_id) - ).fetchall() + if use_driver_cols: + result = connection.execute( + users.select().order_by(users.c.user_id), + execution_options={"driver_column_names": True}, + ).fetchall() + else: + result = connection.execute( + users.select().order_by(users.c.user_id) + ).fetchall() eq_( result, [ diff --git a/test/typing/test_overloads.py b/test/typing/test_overloads.py index 66209f5036..1c50845493 100644 --- a/test/typing/test_overloads.py +++ b/test/typing/test_overloads.py @@ -25,6 +25,7 @@ core_execution_options = { "max_row_buffer": "int", "yield_per": "int", "preserve_rowcount": "bool", + "driver_column_names": "bool", } orm_dql_execution_options = {