From: KapilDagur Date: Mon, 14 Jul 2025 19:36:30 +0000 (-0400) Subject: typing: improve type coverage in sql.base X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=e47d652653e920c459f6db9d3ff7db469b27c41c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git typing: improve type coverage in sql.base improve type coverage in `sqlalchemy.sql.base` References: #6810 Closes: #12707 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12707 Pull-request-sha: 7374212dfc88a43f381f8380d9b4ac193f5ed10b Change-Id: Ied0676f420bc27ae033f0a5e6e22d806d20f4404 --- diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 73f8091984..f428dff185 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -21,7 +21,9 @@ from typing import Any from typing import Callable from typing import cast from typing import Dict +from typing import Final from typing import FrozenSet +from typing import Generator from typing import Generic from typing import Iterable from typing import Iterator @@ -67,6 +69,8 @@ if TYPE_CHECKING: from ._orm_types import DMLStrategyArgument from ._orm_types import SynchronizeSessionArgument from ._typing import _CLE + from ._util_cy import anon_map + from .cache_key import CacheKey from .compiler import SQLCompiler from .dml import Delete from .dml import Insert @@ -115,7 +119,7 @@ class _NoArg(Enum): return f"_NoArg.{self.name}" -NO_ARG = _NoArg.NO_ARG +NO_ARG: Final = _NoArg.NO_ARG class _NoneName(Enum): @@ -123,7 +127,7 @@ class _NoneName(Enum): """indicate a 'deferred' name that was ultimately the value None.""" -_NONE_NAME = _NoneName.NONE_NAME +_NONE_NAME: Final = _NoneName.NONE_NAME _T = TypeVar("_T", bound=Any) @@ -158,7 +162,9 @@ class _DefaultDescriptionTuple(NamedTuple): ) -_never_select_column = operator.attrgetter("_omit_from_statements") +_never_select_column: operator.attrgetter[Any] = operator.attrgetter( + "_omit_from_statements" +) class _EntityNamespace(Protocol): @@ -193,12 +199,12 @@ class Immutable: __slots__ = () - _is_immutable = True + _is_immutable: bool = True - def unique_params(self, *optionaldict, **kwargs): + def unique_params(self, *optionaldict: Any, **kwargs: Any) -> NoReturn: raise NotImplementedError("Immutable objects do not support copying") - def params(self, *optionaldict, **kwargs): + def params(self, *optionaldict: Any, **kwargs: Any) -> NoReturn: raise NotImplementedError("Immutable objects do not support copying") def _clone(self: _Self, **kw: Any) -> _Self: @@ -213,7 +219,7 @@ class Immutable: class SingletonConstant(Immutable): """Represent SQL constants like NULL, TRUE, FALSE""" - _is_singleton_constant = True + _is_singleton_constant: bool = True _singleton: SingletonConstant @@ -225,7 +231,7 @@ class SingletonConstant(Immutable): raise NotImplementedError() @classmethod - def _create_singleton(cls): + def _create_singleton(cls) -> None: obj = object.__new__(cls) obj.__init__() # type: ignore @@ -294,17 +300,17 @@ def _generative(fn: _Fn) -> _Fn: def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]: - msgs = kw.pop("msgs", {}) + msgs: Dict[str, str] = kw.pop("msgs", {}) - defaults = kw.pop("defaults", {}) + defaults: Dict[str, str] = kw.pop("defaults", {}) - getters = [ + getters: List[Tuple[str, operator.attrgetter[Any], Optional[str]]] = [ (name, operator.attrgetter(name), defaults.get(name, None)) for name in names ] @util.decorator - def check(fn, *args, **kw): + def check(fn: _Fn, *args: Any, **kw: Any) -> Any: # make pylance happy by not including "self" in the argument # list self = args[0] @@ -353,12 +359,16 @@ def _cloned_intersection(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]: The returned set is in terms of the entities present within 'a'. """ - all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + all_overlap: Set[_CLE] = set(_expand_cloned(a)).intersection( + _expand_cloned(b) + ) return {elem for elem in a if all_overlap.intersection(elem._cloned_set)} def _cloned_difference(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]: - all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + all_overlap: Set[_CLE] = set(_expand_cloned(a)).intersection( + _expand_cloned(b) + ) return { elem for elem in a if not all_overlap.intersection(elem._cloned_set) } @@ -372,10 +382,10 @@ class _DialectArgView(MutableMapping[str, Any]): __slots__ = ("obj",) - def __init__(self, obj): + def __init__(self, obj: DialectKWArgs) -> None: self.obj = obj - def _key(self, key): + def _key(self, key: str) -> Tuple[str, str]: try: dialect, value_key = key.split("_", 1) except ValueError as err: @@ -383,7 +393,7 @@ class _DialectArgView(MutableMapping[str, Any]): else: return dialect, value_key - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: dialect, value_key = self._key(key) try: @@ -393,7 +403,7 @@ class _DialectArgView(MutableMapping[str, Any]): else: return opt[value_key] - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: try: dialect, value_key = self._key(key) except KeyError as err: @@ -403,17 +413,17 @@ class _DialectArgView(MutableMapping[str, Any]): else: self.obj.dialect_options[dialect][value_key] = value - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: dialect, value_key = self._key(key) del self.obj.dialect_options[dialect][value_key] - def __len__(self): + def __len__(self) -> int: return sum( len(args._non_defaults) for args in self.obj.dialect_options.values() ) - def __iter__(self): + def __iter__(self) -> Generator[str, None, None]: return ( "%s_%s" % (dialect_name, value_name) for dialect_name in self.obj.dialect_options @@ -432,31 +442,31 @@ class _DialectArgDict(MutableMapping[str, Any]): """ - def __init__(self): - self._non_defaults = {} - self._defaults = {} + def __init__(self) -> None: + self._non_defaults: Dict[str, Any] = {} + self._defaults: Dict[str, Any] = {} - def __len__(self): + def __len__(self) -> int: return len(set(self._non_defaults).union(self._defaults)) - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(set(self._non_defaults).union(self._defaults)) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: if key in self._non_defaults: return self._non_defaults[key] else: return self._defaults[key] - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: self._non_defaults[key] = value - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: del self._non_defaults[key] @util.preload_module("sqlalchemy.dialects") -def _kw_reg_for_dialect(dialect_name): +def _kw_reg_for_dialect(dialect_name: str) -> Optional[Dict[Any, Any]]: dialect_cls = util.preloaded.dialects.registry.load(dialect_name) if dialect_cls.construct_arguments is None: return None @@ -478,12 +488,14 @@ class DialectKWArgs: __slots__ = () - _dialect_kwargs_traverse_internals = [ + _dialect_kwargs_traverse_internals: List[Tuple[str, Any]] = [ ("dialect_options", InternalTraversal.dp_dialect_options) ] @classmethod - def argument_for(cls, dialect_name, argument_name, default): + def argument_for( + cls, dialect_name: str, argument_name: str, default: Any + ) -> None: """Add a new kind of dialect-specific keyword argument for this class. E.g.:: @@ -520,7 +532,9 @@ class DialectKWArgs: """ - construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] + construct_arg_dictionary: Optional[Dict[Any, Any]] = ( + DialectKWArgs._kw_registry[dialect_name] + ) if construct_arg_dictionary is None: raise exc.ArgumentError( "Dialect '%s' does have keyword-argument " @@ -531,7 +545,7 @@ class DialectKWArgs: construct_arg_dictionary[cls][argument_name] = default @property - def dialect_kwargs(self): + def dialect_kwargs(self) -> _DialectArgView: """A collection of keyword arguments specified as dialect-specific options to this construct. @@ -552,14 +566,16 @@ class DialectKWArgs: return _DialectArgView(self) @property - def kwargs(self): + def kwargs(self) -> _DialectArgView: """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`.""" return self.dialect_kwargs - _kw_registry = util.PopulateDict(_kw_reg_for_dialect) + _kw_registry: util.PopulateDict[str, Optional[Dict[Any, Any]]] = ( + util.PopulateDict(_kw_reg_for_dialect) + ) @classmethod - def _kw_reg_for_dialect_cls(cls, dialect_name): + def _kw_reg_for_dialect_cls(cls, dialect_name: str) -> _DialectArgDict: construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] d = _DialectArgDict() @@ -572,7 +588,7 @@ class DialectKWArgs: return d @util.memoized_property - def dialect_options(self): + def dialect_options(self) -> util.PopulateDict[str, _DialectArgDict]: """A collection of keyword arguments specified as dialect-specific options to this construct. @@ -834,7 +850,7 @@ class Options(metaclass=_MetaOptions): ) super().__init_subclass__() - def __init__(self, **kw): + def __init__(self, **kw: Any) -> None: self.__dict__.update(kw) def __add__(self, other): @@ -859,7 +875,7 @@ class Options(metaclass=_MetaOptions): return False return True - def __repr__(self): + def __repr__(self) -> str: # TODO: fairly inefficient, used only in debugging right now. return "%s(%s)" % ( @@ -876,7 +892,7 @@ class Options(metaclass=_MetaOptions): return issubclass(cls, klass) @hybridmethod - def add_to_element(self, name, value): + def add_to_element(self, name: str, value: str) -> Any: return self + {name: getattr(self, name) + value} @hybridmethod @@ -890,7 +906,7 @@ class Options(metaclass=_MetaOptions): return cls._state_dict_const @classmethod - def safe_merge(cls, other): + def safe_merge(cls, other: "Options") -> Any: d = other._state_dict() # only support a merge with another object of our class @@ -916,8 +932,12 @@ class Options(metaclass=_MetaOptions): @classmethod def from_execution_options( - cls, key, attrs, exec_options, statement_exec_options - ): + cls, + key: str, + attrs: set[str], + exec_options: Mapping[str, Any], + statement_exec_options: Mapping[str, Any], + ) -> Tuple["Options", Mapping[str, Any]]: """process Options argument in terms of execution options. @@ -977,28 +997,32 @@ class CacheableOptions(Options, HasCacheKey): __slots__ = () @hybridmethod - def _gen_cache_key_inst(self, anon_map, bindparams): + def _gen_cache_key_inst( + self, anon_map: Any, bindparams: List[BindParameter[Any]] + ) -> Optional[Tuple[Any]]: return HasCacheKey._gen_cache_key(self, anon_map, bindparams) @_gen_cache_key_inst.classlevel - def _gen_cache_key(cls, anon_map, bindparams): + def _gen_cache_key( + cls, anon_map: "anon_map", bindparams: List[BindParameter[Any]] + ) -> Tuple[CacheableOptions, Any]: return (cls, ()) @hybridmethod - def _generate_cache_key(self): + def _generate_cache_key(self) -> Optional[CacheKey]: return HasCacheKey._generate_cache_key_for_object(self) class ExecutableOption(HasCopyInternals): __slots__ = () - _annotations = util.EMPTY_DICT + _annotations: _ImmutableExecuteOptions = util.EMPTY_DICT - __visit_name__ = "executable_option" + __visit_name__: str = "executable_option" - _is_has_cache_key = False + _is_has_cache_key: bool = False - _is_core = True + _is_core: bool = True def _clone(self, **kw): """Create a shallow copy of this ExecutableOption.""" @@ -1228,7 +1252,7 @@ class Executable(roles.StatementRole): supports_execution: bool = True _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT - _is_default_generator = False + _is_default_generator: bool = False _with_options: Tuple[ExecutableOption, ...] = () _compile_state_funcs: Tuple[ Tuple[Callable[[CompileState], None], Any], ... @@ -1244,13 +1268,13 @@ class Executable(roles.StatementRole): ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs), ] - is_select = False - is_from_statement = False - is_update = False - is_insert = False - is_text = False - is_delete = False - is_dml = False + is_select: bool = False + is_from_statement: bool = False + is_update: bool = False + is_insert: bool = False + is_text: bool = False + is_delete: bool = False + is_dml: bool = False if TYPE_CHECKING: __visit_name__: str @@ -1283,7 +1307,7 @@ class Executable(roles.StatementRole): ) -> Any: ... @util.ro_non_memoized_property - def _all_selected_columns(self): + def _all_selected_columns(self) -> _SelectIterable: raise NotImplementedError() @property @@ -1552,7 +1576,7 @@ class SchemaVisitor(ClauseVisitor): """ - __traverse_options__ = {"schema_visitor": True} + __traverse_options__: Dict[str, Any] = {"schema_visitor": True} class _SentinelDefaultCharacterization(Enum): @@ -1587,7 +1611,7 @@ class _ColumnMetrics(Generic[_COL_co]): def __init__( self, collection: ColumnCollection[Any, _COL_co], col: _COL_co - ): + ) -> None: self.column = col # proxy_index being non-empty means it was initialized. @@ -1597,10 +1621,10 @@ class _ColumnMetrics(Generic[_COL_co]): for eps_col in col._expanded_proxy_set: pi[eps_col].add(self) - def get_expanded_proxy_set(self): + def get_expanded_proxy_set(self) -> FrozenSet[ColumnElement[Any]]: return self.column._expanded_proxy_set - def dispose(self, collection): + def dispose(self, collection: ColumnCollection[_COLKEY, _COL_co]) -> None: pi = collection._proxy_index if not pi: return @@ -1733,7 +1757,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): """ - __slots__ = "_collection", "_index", "_colset", "_proxy_index" + __slots__ = ("_collection", "_index", "_colset", "_proxy_index") _collection: List[Tuple[_COLKEY, _COL_co, _ColumnMetrics[_COL_co]]] _index: Dict[Union[None, str, int], Tuple[_COLKEY, _COL_co]] @@ -1852,7 +1876,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): else: return True - def compare(self, other: ColumnCollection[Any, Any]) -> bool: + def compare(self, other: ColumnCollection[_COLKEY, _COL_co]) -> bool: """Compare this :class:`_expression.ColumnCollection` to another based on the names of the keys""" @@ -1903,7 +1927,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): :class:`_sql.ColumnCollection`.""" raise NotImplementedError() - def remove(self, column: Any) -> None: + def remove(self, column: Any) -> NoReturn: raise NotImplementedError() def update(self, iter_: Any) -> NoReturn: @@ -1912,7 +1936,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): raise NotImplementedError() # https://github.com/python/mypy/issues/4266 - __hash__ = None # type: ignore + __hash__: Optional[int] = None # type: ignore def _populate_separate_keys( self, iter_: Iterable[Tuple[_COLKEY, _COL_co]] @@ -2008,7 +2032,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): return ReadOnlyColumnCollection(self) - def _init_proxy_index(self): + def _init_proxy_index(self) -> None: """populate the "proxy index", if empty. proxy index is added in 2.0 to provide more efficient operation @@ -2252,7 +2276,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): def extend(self, iter_: Iterable[_NAMEDCOL]) -> None: self._populate_separate_keys((col.key, col) for col in iter_) - def remove(self, column: _NAMEDCOL) -> None: + def remove(self, column: _NAMEDCOL) -> None: # type: ignore[override] if column not in self._colset: raise ValueError( "Can't remove column %r; column is not in this collection" @@ -2370,17 +2394,17 @@ class ReadOnlyColumnCollection( ): __slots__ = ("_parent",) - def __init__(self, collection): + def __init__(self, collection: ColumnCollection[_COLKEY, _COL_co]): object.__setattr__(self, "_parent", collection) object.__setattr__(self, "_colset", collection._colset) object.__setattr__(self, "_index", collection._index) object.__setattr__(self, "_collection", collection._collection) object.__setattr__(self, "_proxy_index", collection._proxy_index) - def __getstate__(self): + def __getstate__(self) -> Dict[str, _COL_co]: return {"_parent": self._parent} - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: parent = state["_parent"] self.__init__(parent) # type: ignore @@ -2395,10 +2419,10 @@ class ReadOnlyColumnCollection( class ColumnSet(util.OrderedSet["ColumnClause[Any]"]): - def contains_column(self, col): + def contains_column(self, col: ColumnClause[Any]) -> bool: return col in self - def extend(self, cols): + def extend(self, cols: Iterable[Any]) -> None: for col in cols: self.add(col) @@ -2410,7 +2434,7 @@ class ColumnSet(util.OrderedSet["ColumnClause[Any]"]): l.append(c == local) return elements.and_(*l) - def __hash__(self): # type: ignore[override] + def __hash__(self) -> int: # type: ignore[override] return hash(tuple(x for x in self)) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 349f189302..413903b4f6 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1346,7 +1346,7 @@ class Join(roles.DMLTableRole, FromClause): c for c in self.right.c ] - primary_key.extend( # type: ignore + primary_key.extend( sqlutil.reduce_columns( (c for c in _columns if c.primary_key), self.onclause )