]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
typing: improve type coverage in sql.base
authorKapilDagur <kapildagur1306@gmail.com>
Mon, 14 Jul 2025 19:36:30 +0000 (15:36 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 15 Jul 2025 07:55:42 +0000 (07:55 +0000)
improve type coverage in `sqlalchemy.sql.base`

References: #6810
Closes: #12707
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12707
Pull-request-sha: 7374212dfc88a43f381f8380d9b4ac193f5ed10b

Change-Id: Ied0676f420bc27ae033f0a5e6e22d806d20f4404

lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/selectable.py

index 73f809198461a624dca709f95342ec71b0d29d2b..f428dff185b08763b18095d119c534ffcbc40330 100644 (file)
@@ -21,7 +21,9 @@ from typing import Any
 from typing import Callable
 from typing import cast
 from typing import Dict
+from typing import Final
 from typing import FrozenSet
+from typing import Generator
 from typing import Generic
 from typing import Iterable
 from typing import Iterator
@@ -67,6 +69,8 @@ if TYPE_CHECKING:
     from ._orm_types import DMLStrategyArgument
     from ._orm_types import SynchronizeSessionArgument
     from ._typing import _CLE
+    from ._util_cy import anon_map
+    from .cache_key import CacheKey
     from .compiler import SQLCompiler
     from .dml import Delete
     from .dml import Insert
@@ -115,7 +119,7 @@ class _NoArg(Enum):
         return f"_NoArg.{self.name}"
 
 
-NO_ARG = _NoArg.NO_ARG
+NO_ARG: Final = _NoArg.NO_ARG
 
 
 class _NoneName(Enum):
@@ -123,7 +127,7 @@ class _NoneName(Enum):
     """indicate a 'deferred' name that was ultimately the value None."""
 
 
-_NONE_NAME = _NoneName.NONE_NAME
+_NONE_NAME: Final = _NoneName.NONE_NAME
 
 _T = TypeVar("_T", bound=Any)
 
@@ -158,7 +162,9 @@ class _DefaultDescriptionTuple(NamedTuple):
         )
 
 
-_never_select_column = operator.attrgetter("_omit_from_statements")
+_never_select_column: operator.attrgetter[Any] = operator.attrgetter(
+    "_omit_from_statements"
+)
 
 
 class _EntityNamespace(Protocol):
@@ -193,12 +199,12 @@ class Immutable:
 
     __slots__ = ()
 
-    _is_immutable = True
+    _is_immutable: bool = True
 
-    def unique_params(self, *optionaldict, **kwargs):
+    def unique_params(self, *optionaldict: Any, **kwargs: Any) -> NoReturn:
         raise NotImplementedError("Immutable objects do not support copying")
 
-    def params(self, *optionaldict, **kwargs):
+    def params(self, *optionaldict: Any, **kwargs: Any) -> NoReturn:
         raise NotImplementedError("Immutable objects do not support copying")
 
     def _clone(self: _Self, **kw: Any) -> _Self:
@@ -213,7 +219,7 @@ class Immutable:
 class SingletonConstant(Immutable):
     """Represent SQL constants like NULL, TRUE, FALSE"""
 
-    _is_singleton_constant = True
+    _is_singleton_constant: bool = True
 
     _singleton: SingletonConstant
 
@@ -225,7 +231,7 @@ class SingletonConstant(Immutable):
         raise NotImplementedError()
 
     @classmethod
-    def _create_singleton(cls):
+    def _create_singleton(cls) -> None:
         obj = object.__new__(cls)
         obj.__init__()  # type: ignore
 
@@ -294,17 +300,17 @@ def _generative(fn: _Fn) -> _Fn:
 
 
 def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:
-    msgs = kw.pop("msgs", {})
+    msgs: Dict[str, str] = kw.pop("msgs", {})
 
-    defaults = kw.pop("defaults", {})
+    defaults: Dict[str, str] = kw.pop("defaults", {})
 
-    getters = [
+    getters: List[Tuple[str, operator.attrgetter[Any], Optional[str]]] = [
         (name, operator.attrgetter(name), defaults.get(name, None))
         for name in names
     ]
 
     @util.decorator
-    def check(fn, *args, **kw):
+    def check(fn: _Fn, *args: Any, **kw: Any) -> Any:
         # make pylance happy by not including "self" in the argument
         # list
         self = args[0]
@@ -353,12 +359,16 @@ def _cloned_intersection(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]:
     The returned set is in terms of the entities present within 'a'.
 
     """
-    all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
+    all_overlap: Set[_CLE] = set(_expand_cloned(a)).intersection(
+        _expand_cloned(b)
+    )
     return {elem for elem in a if all_overlap.intersection(elem._cloned_set)}
 
 
 def _cloned_difference(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]:
-    all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
+    all_overlap: Set[_CLE] = set(_expand_cloned(a)).intersection(
+        _expand_cloned(b)
+    )
     return {
         elem for elem in a if not all_overlap.intersection(elem._cloned_set)
     }
@@ -372,10 +382,10 @@ class _DialectArgView(MutableMapping[str, Any]):
 
     __slots__ = ("obj",)
 
-    def __init__(self, obj):
+    def __init__(self, obj: DialectKWArgs) -> None:
         self.obj = obj
 
-    def _key(self, key):
+    def _key(self, key: str) -> Tuple[str, str]:
         try:
             dialect, value_key = key.split("_", 1)
         except ValueError as err:
@@ -383,7 +393,7 @@ class _DialectArgView(MutableMapping[str, Any]):
         else:
             return dialect, value_key
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: str) -> Any:
         dialect, value_key = self._key(key)
 
         try:
@@ -393,7 +403,7 @@ class _DialectArgView(MutableMapping[str, Any]):
         else:
             return opt[value_key]
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: str, value: Any) -> None:
         try:
             dialect, value_key = self._key(key)
         except KeyError as err:
@@ -403,17 +413,17 @@ class _DialectArgView(MutableMapping[str, Any]):
         else:
             self.obj.dialect_options[dialect][value_key] = value
 
-    def __delitem__(self, key):
+    def __delitem__(self, key: str) -> None:
         dialect, value_key = self._key(key)
         del self.obj.dialect_options[dialect][value_key]
 
-    def __len__(self):
+    def __len__(self) -> int:
         return sum(
             len(args._non_defaults)
             for args in self.obj.dialect_options.values()
         )
 
-    def __iter__(self):
+    def __iter__(self) -> Generator[str, None, None]:
         return (
             "%s_%s" % (dialect_name, value_name)
             for dialect_name in self.obj.dialect_options
@@ -432,31 +442,31 @@ class _DialectArgDict(MutableMapping[str, Any]):
 
     """
 
-    def __init__(self):
-        self._non_defaults = {}
-        self._defaults = {}
+    def __init__(self) -> None:
+        self._non_defaults: Dict[str, Any] = {}
+        self._defaults: Dict[str, Any] = {}
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(set(self._non_defaults).union(self._defaults))
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[str]:
         return iter(set(self._non_defaults).union(self._defaults))
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: str) -> Any:
         if key in self._non_defaults:
             return self._non_defaults[key]
         else:
             return self._defaults[key]
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: str, value: Any) -> None:
         self._non_defaults[key] = value
 
-    def __delitem__(self, key):
+    def __delitem__(self, key: str) -> None:
         del self._non_defaults[key]
 
 
 @util.preload_module("sqlalchemy.dialects")
-def _kw_reg_for_dialect(dialect_name):
+def _kw_reg_for_dialect(dialect_name: str) -> Optional[Dict[Any, Any]]:
     dialect_cls = util.preloaded.dialects.registry.load(dialect_name)
     if dialect_cls.construct_arguments is None:
         return None
@@ -478,12 +488,14 @@ class DialectKWArgs:
 
     __slots__ = ()
 
-    _dialect_kwargs_traverse_internals = [
+    _dialect_kwargs_traverse_internals: List[Tuple[str, Any]] = [
         ("dialect_options", InternalTraversal.dp_dialect_options)
     ]
 
     @classmethod
-    def argument_for(cls, dialect_name, argument_name, default):
+    def argument_for(
+        cls, dialect_name: str, argument_name: str, default: Any
+    ) -> None:
         """Add a new kind of dialect-specific keyword argument for this class.
 
         E.g.::
@@ -520,7 +532,9 @@ class DialectKWArgs:
 
         """
 
-        construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
+        construct_arg_dictionary: Optional[Dict[Any, Any]] = (
+            DialectKWArgs._kw_registry[dialect_name]
+        )
         if construct_arg_dictionary is None:
             raise exc.ArgumentError(
                 "Dialect '%s' does have keyword-argument "
@@ -531,7 +545,7 @@ class DialectKWArgs:
         construct_arg_dictionary[cls][argument_name] = default
 
     @property
-    def dialect_kwargs(self):
+    def dialect_kwargs(self) -> _DialectArgView:
         """A collection of keyword arguments specified as dialect-specific
         options to this construct.
 
@@ -552,14 +566,16 @@ class DialectKWArgs:
         return _DialectArgView(self)
 
     @property
-    def kwargs(self):
+    def kwargs(self) -> _DialectArgView:
         """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`."""
         return self.dialect_kwargs
 
-    _kw_registry = util.PopulateDict(_kw_reg_for_dialect)
+    _kw_registry: util.PopulateDict[str, Optional[Dict[Any, Any]]] = (
+        util.PopulateDict(_kw_reg_for_dialect)
+    )
 
     @classmethod
-    def _kw_reg_for_dialect_cls(cls, dialect_name):
+    def _kw_reg_for_dialect_cls(cls, dialect_name: str) -> _DialectArgDict:
         construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
         d = _DialectArgDict()
 
@@ -572,7 +588,7 @@ class DialectKWArgs:
         return d
 
     @util.memoized_property
-    def dialect_options(self):
+    def dialect_options(self) -> util.PopulateDict[str, _DialectArgDict]:
         """A collection of keyword arguments specified as dialect-specific
         options to this construct.
 
@@ -834,7 +850,7 @@ class Options(metaclass=_MetaOptions):
         )
         super().__init_subclass__()
 
-    def __init__(self, **kw):
+    def __init__(self, **kw: Any) -> None:
         self.__dict__.update(kw)
 
     def __add__(self, other):
@@ -859,7 +875,7 @@ class Options(metaclass=_MetaOptions):
                 return False
         return True
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         # TODO: fairly inefficient, used only in debugging right now.
 
         return "%s(%s)" % (
@@ -876,7 +892,7 @@ class Options(metaclass=_MetaOptions):
         return issubclass(cls, klass)
 
     @hybridmethod
-    def add_to_element(self, name, value):
+    def add_to_element(self, name: str, value: str) -> Any:
         return self + {name: getattr(self, name) + value}
 
     @hybridmethod
@@ -890,7 +906,7 @@ class Options(metaclass=_MetaOptions):
         return cls._state_dict_const
 
     @classmethod
-    def safe_merge(cls, other):
+    def safe_merge(cls, other: "Options") -> Any:
         d = other._state_dict()
 
         # only support a merge with another object of our class
@@ -916,8 +932,12 @@ class Options(metaclass=_MetaOptions):
 
     @classmethod
     def from_execution_options(
-        cls, key, attrs, exec_options, statement_exec_options
-    ):
+        cls,
+        key: str,
+        attrs: set[str],
+        exec_options: Mapping[str, Any],
+        statement_exec_options: Mapping[str, Any],
+    ) -> Tuple["Options", Mapping[str, Any]]:
         """process Options argument in terms of execution options.
 
 
@@ -977,28 +997,32 @@ class CacheableOptions(Options, HasCacheKey):
     __slots__ = ()
 
     @hybridmethod
-    def _gen_cache_key_inst(self, anon_map, bindparams):
+    def _gen_cache_key_inst(
+        self, anon_map: Any, bindparams: List[BindParameter[Any]]
+    ) -> Optional[Tuple[Any]]:
         return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
 
     @_gen_cache_key_inst.classlevel
-    def _gen_cache_key(cls, anon_map, bindparams):
+    def _gen_cache_key(
+        cls, anon_map: "anon_map", bindparams: List[BindParameter[Any]]
+    ) -> Tuple[CacheableOptions, Any]:
         return (cls, ())
 
     @hybridmethod
-    def _generate_cache_key(self):
+    def _generate_cache_key(self) -> Optional[CacheKey]:
         return HasCacheKey._generate_cache_key_for_object(self)
 
 
 class ExecutableOption(HasCopyInternals):
     __slots__ = ()
 
-    _annotations = util.EMPTY_DICT
+    _annotations: _ImmutableExecuteOptions = util.EMPTY_DICT
 
-    __visit_name__ = "executable_option"
+    __visit_name__: str = "executable_option"
 
-    _is_has_cache_key = False
+    _is_has_cache_key: bool = False
 
-    _is_core = True
+    _is_core: bool = True
 
     def _clone(self, **kw):
         """Create a shallow copy of this ExecutableOption."""
@@ -1228,7 +1252,7 @@ class Executable(roles.StatementRole):
 
     supports_execution: bool = True
     _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT
-    _is_default_generator = False
+    _is_default_generator: bool = False
     _with_options: Tuple[ExecutableOption, ...] = ()
     _compile_state_funcs: Tuple[
         Tuple[Callable[[CompileState], None], Any], ...
@@ -1244,13 +1268,13 @@ class Executable(roles.StatementRole):
         ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs),
     ]
 
-    is_select = False
-    is_from_statement = False
-    is_update = False
-    is_insert = False
-    is_text = False
-    is_delete = False
-    is_dml = False
+    is_select: bool = False
+    is_from_statement: bool = False
+    is_update: bool = False
+    is_insert: bool = False
+    is_text: bool = False
+    is_delete: bool = False
+    is_dml: bool = False
 
     if TYPE_CHECKING:
         __visit_name__: str
@@ -1283,7 +1307,7 @@ class Executable(roles.StatementRole):
         ) -> Any: ...
 
     @util.ro_non_memoized_property
-    def _all_selected_columns(self):
+    def _all_selected_columns(self) -> _SelectIterable:
         raise NotImplementedError()
 
     @property
@@ -1552,7 +1576,7 @@ class SchemaVisitor(ClauseVisitor):
 
     """
 
-    __traverse_options__ = {"schema_visitor": True}
+    __traverse_options__: Dict[str, Any] = {"schema_visitor": True}
 
 
 class _SentinelDefaultCharacterization(Enum):
@@ -1587,7 +1611,7 @@ class _ColumnMetrics(Generic[_COL_co]):
 
     def __init__(
         self, collection: ColumnCollection[Any, _COL_co], col: _COL_co
-    ):
+    ) -> None:
         self.column = col
 
         # proxy_index being non-empty means it was initialized.
@@ -1597,10 +1621,10 @@ class _ColumnMetrics(Generic[_COL_co]):
             for eps_col in col._expanded_proxy_set:
                 pi[eps_col].add(self)
 
-    def get_expanded_proxy_set(self):
+    def get_expanded_proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
         return self.column._expanded_proxy_set
 
-    def dispose(self, collection):
+    def dispose(self, collection: ColumnCollection[_COLKEY, _COL_co]) -> None:
         pi = collection._proxy_index
         if not pi:
             return
@@ -1733,7 +1757,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
 
     """
 
-    __slots__ = "_collection", "_index", "_colset", "_proxy_index"
+    __slots__ = ("_collection", "_index", "_colset", "_proxy_index")
 
     _collection: List[Tuple[_COLKEY, _COL_co, _ColumnMetrics[_COL_co]]]
     _index: Dict[Union[None, str, int], Tuple[_COLKEY, _COL_co]]
@@ -1852,7 +1876,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
         else:
             return True
 
-    def compare(self, other: ColumnCollection[Any, Any]) -> bool:
+    def compare(self, other: ColumnCollection[_COLKEY, _COL_co]) -> bool:
         """Compare this :class:`_expression.ColumnCollection` to another
         based on the names of the keys"""
 
@@ -1903,7 +1927,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
         :class:`_sql.ColumnCollection`."""
         raise NotImplementedError()
 
-    def remove(self, column: Any) -> None:
+    def remove(self, column: Any) -> NoReturn:
         raise NotImplementedError()
 
     def update(self, iter_: Any) -> NoReturn:
@@ -1912,7 +1936,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
         raise NotImplementedError()
 
     # https://github.com/python/mypy/issues/4266
-    __hash__ = None  # type: ignore
+    __hash__: Optional[int] = None  # type: ignore
 
     def _populate_separate_keys(
         self, iter_: Iterable[Tuple[_COLKEY, _COL_co]]
@@ -2008,7 +2032,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
 
         return ReadOnlyColumnCollection(self)
 
-    def _init_proxy_index(self):
+    def _init_proxy_index(self) -> None:
         """populate the "proxy index", if empty.
 
         proxy index is added in 2.0 to provide more efficient operation
@@ -2252,7 +2276,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
     def extend(self, iter_: Iterable[_NAMEDCOL]) -> None:
         self._populate_separate_keys((col.key, col) for col in iter_)
 
-    def remove(self, column: _NAMEDCOL) -> None:
+    def remove(self, column: _NAMEDCOL) -> None:  # type: ignore[override]
         if column not in self._colset:
             raise ValueError(
                 "Can't remove column %r; column is not in this collection"
@@ -2370,17 +2394,17 @@ class ReadOnlyColumnCollection(
 ):
     __slots__ = ("_parent",)
 
-    def __init__(self, collection):
+    def __init__(self, collection: ColumnCollection[_COLKEY, _COL_co]):
         object.__setattr__(self, "_parent", collection)
         object.__setattr__(self, "_colset", collection._colset)
         object.__setattr__(self, "_index", collection._index)
         object.__setattr__(self, "_collection", collection._collection)
         object.__setattr__(self, "_proxy_index", collection._proxy_index)
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, _COL_co]:
         return {"_parent": self._parent}
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         parent = state["_parent"]
         self.__init__(parent)  # type: ignore
 
@@ -2395,10 +2419,10 @@ class ReadOnlyColumnCollection(
 
 
 class ColumnSet(util.OrderedSet["ColumnClause[Any]"]):
-    def contains_column(self, col):
+    def contains_column(self, col: ColumnClause[Any]) -> bool:
         return col in self
 
-    def extend(self, cols):
+    def extend(self, cols: Iterable[Any]) -> None:
         for col in cols:
             self.add(col)
 
@@ -2410,7 +2434,7 @@ class ColumnSet(util.OrderedSet["ColumnClause[Any]"]):
                     l.append(c == local)
         return elements.and_(*l)
 
-    def __hash__(self):  # type: ignore[override]
+    def __hash__(self) -> int:  # type: ignore[override]
         return hash(tuple(x for x in self))
 
 
index 349f189302aa268d9a0682d6a92031dcaef19f6e..413903b4f689aa94599a7bf8e24470ec3feb2787 100644 (file)
@@ -1346,7 +1346,7 @@ class Join(roles.DMLTableRole, FromClause):
             c for c in self.right.c
         ]
 
-        primary_key.extend(  # type: ignore
+        primary_key.extend(
             sqlutil.reduce_columns(
                 (c for c in _columns if c.primary_key), self.onclause
             )