From 2804f49e1922d3cdc67065b733d0dd6e06554905 Mon Sep 17 00:00:00 2001 From: Denis Laxalde Date: Fri, 25 Jul 2025 10:16:00 +0200 Subject: [PATCH] Complete type annotations of sqlalchemy.engine.cursor module Related to #6810. --- lib/sqlalchemy/engine/cursor.py | 353 +++++++++++++++++++++---------- lib/sqlalchemy/engine/default.py | 13 +- lib/sqlalchemy/sql/base.py | 2 +- 3 files changed, 248 insertions(+), 120 deletions(-) diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 165ae2feaa..269585eb03 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-defs, allow-untyped-calls """Define cursor-specific result set constructs including :class:`.CursorResult`.""" @@ -19,6 +18,7 @@ import typing from typing import Any from typing import cast from typing import ClassVar +from typing import Deque from typing import Dict from typing import Iterable from typing import Iterator @@ -61,7 +61,10 @@ if typing.TYPE_CHECKING: from .base import Connection from .default import DefaultExecutionContext from .interfaces import _DBAPICursorDescription + from .interfaces import _MutableCoreSingleExecuteParams + from .interfaces import CoreExecuteOptionsParameter from .interfaces import DBAPICursor + from .interfaces import DBAPIType from .interfaces import Dialect from .interfaces import ExecutionContext from .result import _KeyIndexType @@ -70,6 +73,8 @@ if typing.TYPE_CHECKING: from .result import _KeyType from .result import _ProcessorsType from .result import _TupleGetterType + from ..sql.schema import Column + from ..sql.sqltypes import NullType from ..sql.type_api import _ResultProcessorType @@ -119,7 +124,7 @@ MD_UNTRANSLATED: Literal[6] = 6 _CursorKeyMapRecType = Tuple[ Optional[int], # MD_INDEX, None means the record is ambiguously named int, # MD_RESULT_MAP_INDEX - List[Any], # MD_OBJECTS + TupleAny, # MD_OBJECTS str, # MD_LOOKUP_KEY str, # MD_RENDERED_NAME Optional["_ResultProcessorType[Any]"], # MD_PROCESSOR @@ -428,7 +433,7 @@ class CursorResultMetaData(ResultMetaData): # column keys and other names if num_ctx_cols: # keymap by primary string... - by_key = { + by_key: Dict[_KeyType, Any] = { metadata_entry[MD_LOOKUP_KEY]: metadata_entry for metadata_entry in raw } @@ -474,7 +479,7 @@ class CursorResultMetaData(ResultMetaData): # record into by_key. by_key.update( { - key: (None, None, [], key, key, None, None) + key: (None, None, (), key, key, None, None) for key in dupes } ) @@ -516,7 +521,7 @@ class CursorResultMetaData(ResultMetaData): ): self._keymap.update( { - metadata_entry[MD_UNTRANSLATED]: self._keymap[ + metadata_entry[MD_UNTRANSLATED]: self._keymap[ # type: ignore[misc] # noqa: E501 metadata_entry[MD_LOOKUP_KEY] ] for metadata_entry in raw @@ -528,16 +533,16 @@ class CursorResultMetaData(ResultMetaData): def _merge_cursor_description( self, - context, - cursor_description, - result_columns, - num_ctx_cols, - cols_are_ordered, - textual_ordered, - ad_hoc_textual, - loose_column_name_matching, - driver_column_names, - ): + context: DefaultExecutionContext, + cursor_description: _DBAPICursorDescription, + result_columns: Sequence[ResultColumnsEntry], + num_ctx_cols: int, + cols_are_ordered: bool, + textual_ordered: bool, + ad_hoc_textual: bool, + loose_column_name_matching: bool, + driver_column_names: bool, + ) -> List[_CursorKeyMapRecType]: """Merge a cursor.description with compiled result column information. There are at least four separate strategies used here, selected @@ -674,7 +679,7 @@ class CursorResultMetaData(ResultMetaData): mapped_type, cursor_colname, coltype ), untranslated, - ) + ) # type: ignore[misc] for ( idx, ridx, @@ -687,8 +692,11 @@ class CursorResultMetaData(ResultMetaData): ] def _colnames_from_description( - self, context, cursor_description, driver_column_names - ): + self, + context: DefaultExecutionContext, + cursor_description: _DBAPICursorDescription, + driver_column_names: bool, + ) -> Iterator[Tuple[int, str, str, Optional[str], DBAPIType]]: """Extract column names and data types from a cursor.description. Applies unicode decoding, column translation, "normalization", @@ -698,7 +706,7 @@ class CursorResultMetaData(ResultMetaData): dialect = context.dialect translate_colname = context._translate_colname normalize_name = ( - dialect.normalize_name if dialect.requires_name_normalize else None + dialect.normalize_name if dialect.requires_name_normalize else None # type: ignore[attr-defined] # noqa: E501 ) untranslated = None @@ -726,8 +734,22 @@ class CursorResultMetaData(ResultMetaData): yield idx, colname, unnormalized, untranslated, coltype def _merge_textual_cols_by_position( - self, context, cursor_description, result_columns, driver_column_names - ): + self, + context: DefaultExecutionContext, + cursor_description: _DBAPICursorDescription, + result_columns: Sequence[ResultColumnsEntry], + driver_column_names: bool, + ) -> Iterator[ + Tuple[ + int, + Optional[int], + str, + TypeEngine[Any], + DBAPIType, + Optional[TupleAny], + Optional[str], + ] + ]: num_ctx_cols = len(result_columns) if num_ctx_cols > len(cursor_description): @@ -740,7 +762,7 @@ class CursorResultMetaData(ResultMetaData): self._keys = [] - uses_denormalize = context.dialect.requires_name_normalize + uses_denormalize = context.dialect.requires_name_normalize # type: ignore[attr-defined] # noqa: E501 for ( idx, colname, @@ -801,12 +823,22 @@ class CursorResultMetaData(ResultMetaData): def _merge_cols_by_name( self, - context, - cursor_description, - result_columns, - loose_column_name_matching, - driver_column_names, - ): + context: DefaultExecutionContext, + cursor_description: _DBAPICursorDescription, + result_columns: Sequence[ResultColumnsEntry], + loose_column_name_matching: bool, + driver_column_names: bool, + ) -> Iterator[ + Tuple[ + int, + Optional[int], + str, + Union[NullType, TypeEngine[Any]], + DBAPIType, + Optional[TupleAny], + Optional[str], + ] + ]: match_map = self._create_description_match_map( result_columns, loose_column_name_matching ) @@ -852,11 +884,9 @@ class CursorResultMetaData(ResultMetaData): @classmethod def _create_description_match_map( cls, - result_columns: List[ResultColumnsEntry], + result_columns: Sequence[ResultColumnsEntry], loose_column_name_matching: bool = False, - ) -> Dict[ - Union[str, object], Tuple[str, Tuple[Any, ...], TypeEngine[Any], int] - ]: + ) -> Dict[Union[str, object], Tuple[str, TupleAny, TypeEngine[Any], int]]: """when matching cursor.description to a set of names that are present in a Compiled object, as is the case with TextualSelect, get all the names we expect might match those in cursor.description. @@ -864,7 +894,7 @@ class CursorResultMetaData(ResultMetaData): d: Dict[ Union[str, object], - Tuple[str, Tuple[Any, ...], TypeEngine[Any], int], + Tuple[str, TupleAny, TypeEngine[Any], int], ] = {} for ridx, elem in enumerate(result_columns): key = elem[RM_RENDERED_NAME] @@ -893,8 +923,13 @@ class CursorResultMetaData(ResultMetaData): return d def _merge_cols_by_none( - self, context, cursor_description, driver_column_names - ): + self, + context: DefaultExecutionContext, + cursor_description: _DBAPICursorDescription, + driver_column_names: bool, + ) -> Iterator[ + Tuple[int, None, str, NullType, DBAPIType, None, Optional[str]] + ]: self._keys = [] for ( @@ -942,13 +977,17 @@ class CursorResultMetaData(ResultMetaData): else: return None - def _raise_for_ambiguous_column_name(self, rec): + def _raise_for_ambiguous_column_name( + self, rec: _KeyMapRecType + ) -> NoReturn: raise exc.InvalidRequestError( "Ambiguous column name '%s' in " "result set column descriptions" % rec[MD_LOOKUP_KEY] ) - def _index_for_key(self, key: Any, raiseerr: bool = True) -> Optional[int]: + def _index_for_key( + self, key: _KeyIndexType, raiseerr: bool = True + ) -> Optional[int]: # TODO: can consider pre-loading ints and negative ints # into _keymap - also no coverage here if isinstance(key, int): @@ -967,15 +1006,17 @@ class CursorResultMetaData(ResultMetaData): self._raise_for_ambiguous_column_name(rec) return index - def _indexes_for_keys(self, keys): + def _indexes_for_keys( + self, keys: Sequence[_KeyIndexType] + ) -> Sequence[int]: try: - return [self._keymap[key][0] for key in keys] + return [self._keymap[key][0] for key in keys] # type: ignore[index,misc] # noqa: E501 except KeyError as ke: # ensure it raises CursorResultMetaData._key_fallback(self, ke.args[0], ke) def _metadata_for_keys( - self, keys: Sequence[Any] + self, keys: Sequence[_KeyIndexType] ) -> Iterator[_NonAmbigCursorKeyMapRecType]: for key in keys: if isinstance(key, int): @@ -994,7 +1035,7 @@ class CursorResultMetaData(ResultMetaData): yield cast(_NonAmbigCursorKeyMapRecType, rec) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: # TODO: consider serializing this as SimpleResultMetaData return { "_keymap": { @@ -1014,7 +1055,7 @@ class CursorResultMetaData(ResultMetaData): "_translated_indexes": self._translated_indexes, } - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: self._processors = [None for _ in range(len(state["_keys"]))] self._keymap = state["_keymap"] self._keymap_by_result_column_idx = None @@ -1060,7 +1101,7 @@ class ResultFetchStrategy: def yield_per( self, result: CursorResult[Unpack[TupleAny]], - dbapi_cursor: Optional[DBAPICursor], + dbapi_cursor: DBAPICursor, num: int, ) -> None: return @@ -1108,22 +1149,47 @@ class NoCursorFetchStrategy(ResultFetchStrategy): __slots__ = () - def soft_close(self, result, dbapi_cursor): + def soft_close( + self, + result: CursorResult[Unpack[TupleAny]], + dbapi_cursor: Optional[DBAPICursor], + ) -> None: pass - def hard_close(self, result, dbapi_cursor): + def hard_close( + self, + result: CursorResult[Unpack[TupleAny]], + dbapi_cursor: Optional[DBAPICursor], + ) -> None: pass - def fetchone(self, result, dbapi_cursor, hard_close=False): + def fetchone( + self, + result: CursorResult[Unpack[TupleAny]], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: return self._non_result(result, None) - def fetchmany(self, result, dbapi_cursor, size=None): + def fetchmany( + self, + result: CursorResult[Unpack[TupleAny]], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: return self._non_result(result, []) - def fetchall(self, result, dbapi_cursor): + def fetchall( + self, result: CursorResult[Unpack[TupleAny]], dbapi_cursor: DBAPICursor + ) -> Any: return self._non_result(result, []) - def _non_result(self, result, default, err=None): + def _non_result( + self, + result: CursorResult[Unpack[TupleAny]], + default: Any, + err: Optional[BaseException] = None, + ) -> Any: raise NotImplementedError() @@ -1140,7 +1206,12 @@ class NoCursorDQLFetchStrategy(NoCursorFetchStrategy): __slots__ = () - def _non_result(self, result, default, err=None): + def _non_result( + self, + result: CursorResult[Unpack[TupleAny]], + default: Any, + err: Optional[BaseException] = None, + ) -> Any: if result.closed: raise exc.ResourceClosedError( "This result object is closed." @@ -1162,10 +1233,15 @@ class NoCursorDMLFetchStrategy(NoCursorFetchStrategy): __slots__ = () - def _non_result(self, result, default, err=None): + def _non_result( + self, + result: CursorResult[Unpack[TupleAny]], + default: Any, + err: Optional[BaseException] = None, + ) -> Any: # we only expect to have a _NoResultMetaData() here right now. assert not result._metadata.returns_rows - result._metadata._we_dont_return_rows(err) + result._metadata._we_dont_return_rows(err) # type: ignore[union-attr] _NO_CURSOR_DML = NoCursorDMLFetchStrategy() @@ -1202,10 +1278,7 @@ class CursorFetchStrategy(ResultFetchStrategy): ) def yield_per( - self, - result: CursorResult[Any], - dbapi_cursor: Optional[DBAPICursor], - num: int, + self, result: CursorResult[Any], dbapi_cursor: DBAPICursor, num: int ) -> None: result.cursor_strategy = BufferedRowCursorFetchStrategy( dbapi_cursor, @@ -1294,11 +1367,11 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): def __init__( self, - dbapi_cursor, - execution_options, - growth_factor=5, - initial_buffer=None, - ): + dbapi_cursor: DBAPICursor, + execution_options: CoreExecuteOptionsParameter, + growth_factor: int = 5, + initial_buffer: Optional[Deque[Any]] = None, + ) -> None: self._max_row_buffer = execution_options.get("max_row_buffer", 1000) if initial_buffer is not None: @@ -1313,13 +1386,17 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): self._bufsize = self._max_row_buffer @classmethod - def create(cls, result): + def create( + cls, result: CursorResult[Any] + ) -> BufferedRowCursorFetchStrategy: return BufferedRowCursorFetchStrategy( result.cursor, result.context.execution_options, ) - def _buffer_rows(self, result, dbapi_cursor): + def _buffer_rows( + self, result: CursorResult[Any], dbapi_cursor: DBAPICursor + ) -> None: """this is currently used only by fetchone().""" size = self._bufsize @@ -1339,19 +1416,30 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): self._max_row_buffer, size * self._growth_factor ) - def yield_per(self, result, dbapi_cursor, num): + def yield_per( + self, result: CursorResult[Any], dbapi_cursor: DBAPICursor, num: int + ) -> None: self._growth_factor = 0 self._max_row_buffer = self._bufsize = num - def soft_close(self, result, dbapi_cursor): + def soft_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: self._rowbuffer.clear() super().soft_close(result, dbapi_cursor) - def hard_close(self, result, dbapi_cursor): + def hard_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: self._rowbuffer.clear() super().hard_close(result, dbapi_cursor) - def fetchone(self, result, dbapi_cursor, hard_close=False): + def fetchone( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: if not self._rowbuffer: self._buffer_rows(result, dbapi_cursor) if not self._rowbuffer: @@ -1362,7 +1450,12 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): return None return self._rowbuffer.popleft() - def fetchmany(self, result, dbapi_cursor, size=None): + def fetchmany( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: if size is None: return self.fetchall(result, dbapi_cursor) @@ -1386,7 +1479,9 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): result._soft_close() return res - def fetchall(self, result, dbapi_cursor): + def fetchall( + self, result: CursorResult[Any], dbapi_cursor: DBAPICursor + ) -> Any: try: ret = list(self._rowbuffer) + list(dbapi_cursor.fetchall()) self._rowbuffer.clear() @@ -1420,25 +1515,41 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): assert dbapi_cursor is not None self._rowbuffer = collections.deque(dbapi_cursor.fetchall()) - def yield_per(self, result, dbapi_cursor, num): + def yield_per( + self, result: CursorResult[Any], dbapi_cursor: DBAPICursor, num: int + ) -> Any: pass - def soft_close(self, result, dbapi_cursor): + def soft_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: self._rowbuffer.clear() super().soft_close(result, dbapi_cursor) - def hard_close(self, result, dbapi_cursor): + def hard_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: self._rowbuffer.clear() super().hard_close(result, dbapi_cursor) - def fetchone(self, result, dbapi_cursor, hard_close=False): + def fetchone( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: if self._rowbuffer: return self._rowbuffer.popleft() else: result._soft_close(hard=hard_close) return None - def fetchmany(self, result, dbapi_cursor, size=None): + def fetchmany( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: if size is None: return self.fetchall(result, dbapi_cursor) @@ -1448,7 +1559,9 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): result._soft_close() return rows - def fetchall(self, result, dbapi_cursor): + def fetchall( + self, result: CursorResult[Any], dbapi_cursor: DBAPICursor + ) -> Any: ret = self._rowbuffer self._rowbuffer = collections.deque() result._soft_close() @@ -1460,35 +1573,37 @@ class _NoResultMetaData(ResultMetaData): returns_rows = False - def _we_dont_return_rows(self, err=None): + def _we_dont_return_rows( + self, err: Optional[BaseException] = None + ) -> NoReturn: raise exc.ResourceClosedError( "This result object does not return rows. " "It has been closed automatically." ) from err - def _index_for_key(self, keys, raiseerr): + def _index_for_key(self, keys: _KeyIndexType, raiseerr: bool) -> NoReturn: self._we_dont_return_rows() - def _metadata_for_keys(self, key): + def _metadata_for_keys(self, keys: Sequence[_KeyIndexType]) -> NoReturn: self._we_dont_return_rows() - def _reduce(self, keys): + def _reduce(self, keys: Sequence[_KeyIndexType]) -> NoReturn: self._we_dont_return_rows() @property - def _keymap(self): # type: ignore[override] + def _keymap(self) -> NoReturn: # type: ignore[override] self._we_dont_return_rows() @property - def _key_to_index(self): # type: ignore[override] + def _key_to_index(self) -> NoReturn: # type: ignore[override] self._we_dont_return_rows() @property - def _processors(self): # type: ignore[override] + def _processors(self) -> NoReturn: # type: ignore[override] self._we_dont_return_rows() @property - def keys(self): + def keys(self) -> NoReturn: self._we_dont_return_rows() @@ -1576,7 +1691,7 @@ class CursorResult(Result[Unpack[_Ts]]): if tf is not None: _fixed_tf = tf # needed to make mypy happy... - def _sliced_row(raw_data): + def _sliced_row(raw_data: Any) -> Any: return _make_row(_fixed_tf(raw_data)) sliced_row = _sliced_row @@ -1586,18 +1701,18 @@ class CursorResult(Result[Unpack[_Ts]]): if echo: log = self.context.connection._log_debug - def _log_row(row): + def _log_row(row: Any) -> Any: log("Row %r", sql_util._repr_row(row)) return row self._row_logging_fn = _log_row - def _make_row_2(row): + def _make_row_2(row: Any) -> Any: return _log_row(sliced_row(row)) make_row = _make_row_2 else: - make_row = sliced_row + make_row = sliced_row # type: ignore[assignment] self._set_memoized_attribute("_row_getter", make_row) else: @@ -1678,7 +1793,7 @@ class CursorResult(Result[Unpack[_Ts]]): ) return metadata - def _soft_close(self, hard=False): + def _soft_close(self, hard: bool = False) -> None: """Soft close this :class:`_engine.CursorResult`. This releases all DBAPI cursor resources, but leaves the @@ -1716,7 +1831,7 @@ class CursorResult(Result[Unpack[_Ts]]): self._soft_closed = True @property - def inserted_primary_key_rows(self): + def inserted_primary_key_rows(self) -> List[Optional[Any]]: """Return the value of :attr:`_engine.CursorResult.inserted_primary_key` as a row contained within a list; some dialects may support a @@ -1775,10 +1890,10 @@ class CursorResult(Result[Unpack[_Ts]]): "when returning() " "is used." ) - return self.context.inserted_primary_key_rows + return self.context.inserted_primary_key_rows # type: ignore[no-any-return] # noqa: E501 @property - def inserted_primary_key(self): + def inserted_primary_key(self) -> Optional[Any]: """Return the primary key for the row just inserted. The return value is a :class:`_result.Row` object representing @@ -1823,7 +1938,11 @@ class CursorResult(Result[Unpack[_Ts]]): else: return None - def last_updated_params(self): + def last_updated_params( + self, + ) -> Union[ + List[_MutableCoreSingleExecuteParams], _MutableCoreSingleExecuteParams + ]: """Return the collection of updated parameters from this execution. @@ -1845,7 +1964,11 @@ class CursorResult(Result[Unpack[_Ts]]): else: return self.context.compiled_parameters[0] - def last_inserted_params(self): + def last_inserted_params( + self, + ) -> Union[ + List[_MutableCoreSingleExecuteParams], _MutableCoreSingleExecuteParams + ]: """Return the collection of inserted parameters from this execution. @@ -1868,7 +1991,9 @@ class CursorResult(Result[Unpack[_Ts]]): return self.context.compiled_parameters[0] @property - def returned_defaults_rows(self): + def returned_defaults_rows( + self, + ) -> Optional[Sequence[Row[Unpack[TupleAny]]]]: """Return a list of rows each containing the values of default columns that were fetched using the :meth:`.ValuesBase.return_defaults` feature. @@ -1880,9 +2005,7 @@ class CursorResult(Result[Unpack[_Ts]]): """ return self.context.returned_default_rows - def splice_horizontally( - self, other: CursorResult[Any] - ) -> CursorResult[Any]: + def splice_horizontally(self, other: CursorResult[Any]) -> Self: """Return a new :class:`.CursorResult` that "horizontally splices" together the rows of this :class:`.CursorResult` with that of another :class:`.CursorResult`. @@ -1937,7 +2060,7 @@ class CursorResult(Result[Unpack[_Ts]]): """ # noqa: E501 - clone: CursorResult[Any] = self._generate() + clone = self._generate() assert clone is self # just to note assert isinstance(other._metadata, CursorResultMetaData) assert isinstance(self._metadata, CursorResultMetaData) @@ -1961,7 +2084,7 @@ class CursorResult(Result[Unpack[_Ts]]): clone._reset_memoizations() return clone - def splice_vertically(self, other): + def splice_vertically(self, other: CursorResult[Any]) -> Self: """Return a new :class:`.CursorResult` that "vertically splices", i.e. "extends", the rows of this :class:`.CursorResult` with that of another :class:`.CursorResult`. @@ -1993,7 +2116,7 @@ class CursorResult(Result[Unpack[_Ts]]): clone._reset_memoizations() return clone - def _rewind(self, rows): + def _rewind(self, rows: Any) -> Self: """rewind this result back to the given rowset. this is used internally for the case where an :class:`.Insert` @@ -2029,7 +2152,7 @@ class CursorResult(Result[Unpack[_Ts]]): return self @property - def returned_defaults(self): + def returned_defaults(self) -> Optional[Row[Unpack[TupleAny]]]: """Return the values of default columns that were fetched using the :meth:`.ValuesBase.return_defaults` feature. @@ -2055,7 +2178,7 @@ class CursorResult(Result[Unpack[_Ts]]): else: return None - def lastrow_has_defaults(self): + def lastrow_has_defaults(self) -> bool: """Return ``lastrow_has_defaults()`` from the underlying :class:`.ExecutionContext`. @@ -2065,7 +2188,7 @@ class CursorResult(Result[Unpack[_Ts]]): return self.context.lastrow_has_defaults() - def postfetch_cols(self): + def postfetch_cols(self) -> Optional[Sequence[Column[Any]]]: """Return ``postfetch_cols()`` from the underlying :class:`.ExecutionContext`. @@ -2088,7 +2211,7 @@ class CursorResult(Result[Unpack[_Ts]]): ) return self.context.postfetch_cols - def prefetch_cols(self): + def prefetch_cols(self) -> Optional[Sequence[Column[Any]]]: """Return ``prefetch_cols()`` from the underlying :class:`.ExecutionContext`. @@ -2111,7 +2234,7 @@ class CursorResult(Result[Unpack[_Ts]]): ) return self.context.prefetch_cols - def supports_sane_rowcount(self): + def supports_sane_rowcount(self) -> bool: """Return ``supports_sane_rowcount`` from the dialect. See :attr:`_engine.CursorResult.rowcount` for background. @@ -2120,7 +2243,7 @@ class CursorResult(Result[Unpack[_Ts]]): return self.dialect.supports_sane_rowcount - def supports_sane_multi_rowcount(self): + def supports_sane_multi_rowcount(self) -> bool: """Return ``supports_sane_multi_rowcount`` from the dialect. See :attr:`_engine.CursorResult.rowcount` for background. @@ -2212,7 +2335,7 @@ class CursorResult(Result[Unpack[_Ts]]): raise # not called @property - def lastrowid(self): + def lastrowid(self) -> int: """Return the 'lastrowid' accessor on the DBAPI cursor. This is a DBAPI specific method and is only functional @@ -2233,7 +2356,7 @@ class CursorResult(Result[Unpack[_Ts]]): self.cursor_strategy.handle_exception(self, self.cursor, e) @property - def returns_rows(self): + def returns_rows(self) -> bool: """True if this :class:`_engine.CursorResult` returns zero or more rows. @@ -2261,7 +2384,7 @@ class CursorResult(Result[Unpack[_Ts]]): return self._metadata.returns_rows @property - def is_insert(self): + def is_insert(self) -> bool: """True if this :class:`_engine.CursorResult` is the result of a executing an expression language compiled :func:`_expression.insert` construct. @@ -2274,7 +2397,7 @@ class CursorResult(Result[Unpack[_Ts]]): """ return self.context.isinsert - def _fetchiter_impl(self): + def _fetchiter_impl(self) -> Iterator[Any]: fetchone = self.cursor_strategy.fetchone while True: @@ -2283,16 +2406,16 @@ class CursorResult(Result[Unpack[_Ts]]): break yield row - def _fetchone_impl(self, hard_close=False): + def _fetchone_impl(self, hard_close: bool = False) -> Any: return self.cursor_strategy.fetchone(self, self.cursor, hard_close) - def _fetchall_impl(self): + def _fetchall_impl(self) -> Any: return self.cursor_strategy.fetchall(self, self.cursor) - def _fetchmany_impl(self, size=None): + def _fetchmany_impl(self, size: Optional[int] = None) -> Any: return self.cursor_strategy.fetchmany(self, self.cursor, size) - def _raw_row_iterator(self): + def _raw_row_iterator(self) -> Any: return self._fetchiter_impl() def merge( @@ -2306,7 +2429,7 @@ class CursorResult(Result[Unpack[_Ts]]): ) return merged_result - def close(self) -> Any: + def close(self) -> None: """Close this :class:`_engine.CursorResult`. This closes out the underlying DBAPI cursor corresponding to the diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index c8bdb56635..5cbe11dd13 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -87,6 +87,7 @@ if typing.TYPE_CHECKING: from .interfaces import ConnectArgsType from .interfaces import DBAPIConnection from .interfaces import DBAPIModule + from .interfaces import DBAPIType from .interfaces import IsolationLevel from .row import Row from .url import URL @@ -1235,7 +1236,9 @@ class DefaultExecutionContext(ExecutionContext): # a hook for SQLite's translation of # result column names # NOTE: pyhive is using this hook, can't remove it :( - _translate_colname: Optional[Callable[[str], str]] = None + _translate_colname: Optional[ + Callable[[str], Tuple[str, Optional[str]]] + ] = None _expanded_parameters: Mapping[str, List[str]] = util.immutabledict() """used by set_input_sizes(). @@ -1791,7 +1794,9 @@ class DefaultExecutionContext(ExecutionContext): def post_exec(self): pass - def get_result_processor(self, type_, colname, coltype): + def get_result_processor( + self, type_: TypeEngine[Any], colname: str, coltype: DBAPIType + ) -> Optional[_ResultProcessorType[Any]]: """Return a 'result processor' for a given type as present in cursor.description. @@ -1801,7 +1806,7 @@ class DefaultExecutionContext(ExecutionContext): """ return type_._cached_result_processor(self.dialect, coltype) - def get_lastrowid(self): + def get_lastrowid(self) -> int: """return self.cursor.lastrowid, or equivalent, after an INSERT. This may involve calling special cursor functions, issuing a new SELECT @@ -2055,7 +2060,7 @@ class DefaultExecutionContext(ExecutionContext): getter(row, param) for row, param in zip(rows, compiled_params) ] - def lastrow_has_defaults(self): + def lastrow_has_defaults(self) -> bool: return (self.isinsert or self.isupdate) and bool( cast(SQLCompiler, self.compiled).postfetch ) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index fc67892471..62fca371d8 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -780,7 +780,7 @@ class InPlaceGenerative(HasMemoized): __slots__ = () - def _generate(self): + def _generate(self) -> Self: skip = self._memoized_keys # note __dict__ needs to be in __slots__ if this is used for k in skip: -- 2.47.3