]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Type annotations for sqlalchemy.sql.selectable
authorDzmitar <17720985+dzmitar@users.noreply.github.com>
Thu, 12 Jan 2023 15:47:24 +0000 (10:47 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 13 Jan 2023 17:05:15 +0000 (12:05 -0500)
Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Closes: #9028
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9028
Pull-request-sha: e2f8ddeac0b08feaad917285e988acf1e9465a26

Change-Id: I5caad31bfeeed2d224657a55f067ba1d86b8733f

lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/roles.py
lib/sqlalchemy/sql/selectable.py

index 78e196efc21b970160470b63137d9aadc23c9d58..a120629caa14eb2541d896534948c3127a29af0e 100644 (file)
@@ -70,6 +70,8 @@ _T = TypeVar("_T", bound=Any)
 
 _CE = TypeVar("_CE", bound="ColumnElement[Any]")
 
+_CLE = TypeVar("_CLE", bound="ClauseElement")
+
 
 class _HasClauseElement(Protocol):
     """indicates a class that has a __clause_element__() method"""
index 25e214bd39e821a5ed01c03ac731dfbef6e6136d..96ebc782494777a755e1e5911f7a43a38a5ccefd 100644 (file)
@@ -66,6 +66,7 @@ if TYPE_CHECKING:
     from . import type_api
     from ._orm_types import DMLStrategyArgument
     from ._orm_types import SynchronizeSessionArgument
+    from ._typing import _CLE
     from .elements import BindParameter
     from .elements import ClauseList
     from .elements import ColumnClause  # noqa
@@ -282,7 +283,9 @@ def _clone(element, **kw):
     return element._clone(**kw)
 
 
-def _expand_cloned(elements):
+def _expand_cloned(
+    elements: Iterable[_CLE],
+) -> Iterable[_CLE]:
     """expand the given set of ClauseElements to be the set of all 'cloned'
     predecessors.
 
@@ -291,7 +294,7 @@ def _expand_cloned(elements):
     return itertools.chain(*[x._cloned_set for x in elements])
 
 
-def _cloned_intersection(a, b):
+def _cloned_intersection(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]:
     """return the intersection of sets a and b, counting
     any overlap between 'cloned' predecessors.
 
@@ -302,7 +305,7 @@ def _cloned_intersection(a, b):
     return {elem for elem in a if all_overlap.intersection(elem._cloned_set)}
 
 
-def _cloned_difference(a, b):
+def _cloned_difference(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]:
     all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
     return {
         elem for elem in a if not all_overlap.intersection(elem._cloned_set)
index 30e8bd77401034ac9ffc288d17f3af87dac9abcb..9fe65b0cd6401f2d25bafff1afabab075e1e416c 100644 (file)
@@ -14,6 +14,7 @@ import re
 import typing
 from typing import Any
 from typing import Callable
+from typing import cast
 from typing import Dict
 from typing import Iterable
 from typing import Iterator
@@ -21,6 +22,7 @@ from typing import List
 from typing import NoReturn
 from typing import Optional
 from typing import overload
+from typing import Sequence
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
@@ -140,7 +142,11 @@ def _document_text_coercion(
     )
 
 
-def _expression_collection_was_a_list(attrname, fnname, args):
+def _expression_collection_was_a_list(
+    attrname: str,
+    fnname: str,
+    args: Union[Sequence[_T], Sequence[Sequence[_T]]],
+) -> Sequence[_T]:
     if args and isinstance(args[0], (list, set, dict)) and len(args) == 1:
         if isinstance(args[0], list):
             raise exc.ArgumentError(
@@ -149,9 +155,9 @@ def _expression_collection_was_a_list(attrname, fnname, args):
                 "of items, is now passed as a series of positional "
                 "elements, rather than as a list. "
             )
-        return args[0]
+        return cast("Sequence[_T]", args[0])
 
-    return args
+    return cast("Sequence[_T]", args)
 
 
 @overload
index 748e9504b3c133daaea22c2c5eca1061f31f0590..f9c7cac23ed52e6ea312cd49788ebdd0304eba2d 100644 (file)
@@ -582,7 +582,9 @@ class ClauseElement(
         """
         return traversals.compare(self, other, **kw)
 
-    def self_group(self, against: Optional[OperatorType] = None) -> Any:
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> ClauseElement:
         """Apply a 'grouping' to this :class:`_expression.ClauseElement`.
 
         This method is overridden by subclasses to return a "grouping"
@@ -609,7 +611,7 @@ class ClauseElement(
         """
         return self
 
-    def _ungroup(self):
+    def _ungroup(self) -> ClauseElement:
         """Return this :class:`_expression.ClauseElement`
         without any groupings.
         """
@@ -3452,6 +3454,8 @@ class UnaryExpression(ColumnElement[_T]):
         ("modifier", InternalTraversal.dp_operator),
     ]
 
+    element: ClauseElement
+
     def __init__(
         self,
         element: ColumnElement[Any],
index 902811037e7e73deb372de5463da35c5a09ddc57..ca30ab5ead4833acf0af576ee09b26877be0b75e 100644 (file)
@@ -628,7 +628,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
 
         return TableValuedAlias._construct(
             self,
-            name,
+            name=name,
             table_value_type=self.type,
             joins_implicitly=joins_implicitly,
         )
index 35d1708e2431fd4795b9a920e37ac5b139100b8a..f8aac70b998dc3aad81fc12f85dc1505d2444ff1 100644 (file)
@@ -225,7 +225,7 @@ class AnonymizedFromClauseRole(StrictFromClauseRole):
     if TYPE_CHECKING:
 
         def _anonymous_fromclause(
-            self, name: Optional[str] = None, flat: bool = False
+            self, *, name: Optional[str] = None, flat: bool = False
         ) -> FromClause:
             ...
 
index 54230d58a68b21d18b8a9ea937735c297b37c2b8..511c4855147d16b78a29baa9931ff4297765d521 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
 
 """The :class:`_expression.FromClause` class of SQL expression elements,
 representing
@@ -54,6 +53,7 @@ from ._typing import is_column_element
 from ._typing import is_select_statement
 from ._typing import is_subquery
 from ._typing import is_table
+from ._typing import is_text_clause
 from .annotation import Annotated
 from .annotation import SupportsCloneAnnotations
 from .base import _clone
@@ -131,6 +131,7 @@ if TYPE_CHECKING:
     from .compiler import SQLCompiler
     from .dml import Delete
     from .dml import Update
+    from .elements import BinaryExpression
     from .elements import KeyedColumnElement
     from .elements import Label
     from .elements import NamedColumn
@@ -138,6 +139,7 @@ if TYPE_CHECKING:
     from .functions import Function
     from .schema import ForeignKey
     from .schema import ForeignKeyConstraint
+    from .sqltypes import TableValueType
     from .type_api import TypeEngine
     from .visitors import _CloneCallableType
 
@@ -153,6 +155,10 @@ class _JoinTargetProtocol(Protocol):
     def _from_objects(self) -> List[FromClause]:
         ...
 
+    @util.ro_non_memoized_property
+    def entity_namespace(self) -> _EntityNamespace:
+        ...
+
 
 _JoinTargetElement = Union["FromClause", _JoinTargetProtocol]
 _OnClauseElement = Union["ColumnElement[bool]", _JoinTargetProtocol]
@@ -295,7 +301,7 @@ class Selectable(ReturnsRows):
             :ref:`tutorial_lateral_correlation` -  overview of usage.
 
         """
-        return Lateral._construct(self, name)
+        return Lateral._construct(self, name=name)
 
     @util.deprecated(
         "1.4",
@@ -757,7 +763,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
 
         """
 
-        return Alias._construct(self, name)
+        return Alias._construct(self, name=name)
 
     def tablesample(
         self,
@@ -778,7 +784,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
             :func:`_expression.tablesample` - usage guidelines and parameters
 
         """
-        return TableSample._construct(self, sampling, name, seed)
+        return TableSample._construct(
+            self, sampling=sampling, name=name, seed=seed
+        )
 
     def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
         """Return ``True`` if this :class:`_expression.FromClause` is
@@ -991,8 +999,8 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
         self._reset_column_collection()
 
     def _anonymous_fromclause(
-        self, name: Optional[str] = None, flat: bool = False
-    ) -> NamedFromClause:
+        self, *, name: Optional[str] = None, flat: bool = False
+    ) -> FromClause:
         return self.alias(name=name)
 
     if TYPE_CHECKING:
@@ -1252,7 +1260,7 @@ class Join(roles.DMLTableRole, FromClause):
         return FromGrouping(self)
 
     @util.preload_module("sqlalchemy.sql.util")
-    def _populate_column_collection(self):
+    def _populate_column_collection(self) -> None:
         sqlutil = util.preloaded.sql_util
         columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [
             c for c in self.right.c
@@ -1291,10 +1299,14 @@ class Join(roles.DMLTableRole, FromClause):
         # set up a special replace function that will replace for
         # ColumnClause with parent table referring to those
         # replaced FromClause objects
-        def replace(obj, **kw):
+        def replace(
+            obj: Union[BinaryExpression[Any], ColumnClause[Any]],
+            **kw: Any,
+        ) -> Optional[KeyedColumnElement[ColumnElement[Any]]]:
             if isinstance(obj, ColumnClause) and obj.table in new_froms:
                 newelem = new_froms[obj.table].corresponding_column(obj)
                 return newelem
+            return None
 
         kw["replace"] = replace
 
@@ -1311,7 +1323,9 @@ class Join(roles.DMLTableRole, FromClause):
         self.right._refresh_for_new_column(column)
 
     def _match_primaries(
-        self, left: FromClause, right: FromClause
+        self,
+        left: FromClause,
+        right: FromClause,
     ) -> ColumnElement[bool]:
         if isinstance(left, Join):
             left_right = left.right
@@ -1460,8 +1474,12 @@ class Join(roles.DMLTableRole, FromClause):
 
     @classmethod
     def _joincond_trim_constraints(
-        cls, a, b, constraints, consider_as_foreign_keys
-    ):
+        cls,
+        a: FromClause,
+        b: FromClause,
+        constraints: Dict[Any, Any],
+        consider_as_foreign_keys: Optional[Any],
+    ) -> None:
         # more than one constraint matched.  narrow down the list
         # to include just those FKCs that match exactly to
         # "consider_as_foreign_keys".
@@ -1508,7 +1526,9 @@ class Join(roles.DMLTableRole, FromClause):
         return Select(self.left, self.right).select_from(self)
 
     @util.preload_module("sqlalchemy.sql.util")
-    def _anonymous_fromclause(self, name=None, flat=False):
+    def _anonymous_fromclause(
+        self, name: Optional[str] = None, flat: bool = False
+    ) -> TODO_Any:
         sqlutil = util.preloaded.sql_util
         if flat:
             if name is not None:
@@ -1548,7 +1568,7 @@ class Join(roles.DMLTableRole, FromClause):
 
 
 class NoInit:
-    def __init__(self, *arg, **kw):
+    def __init__(self, *arg: Any, **kw: Any):
         raise NotImplementedError(
             "The %s class is not intended to be constructed "
             "directly.  Please use the %s() standalone "
@@ -1597,13 +1617,17 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
 
     @classmethod
     def _construct(
-        cls: Type[_SelfAliasedReturnsRows], *arg: Any, **kw: Any
+        cls: Type[_SelfAliasedReturnsRows],
+        selectable: Any,
+        *,
+        name: Optional[str] = None,
+        **kw: Any,
     ) -> _SelfAliasedReturnsRows:
         obj = cls.__new__(cls)
-        obj._init(*arg, **kw)
+        obj._init(selectable, name=name, **kw)
         return obj
 
-    def _init(self, selectable, name=None):
+    def _init(self, selectable: Any, *, name: Optional[str] = None) -> None:
         self.element = coercions.expect(
             roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self
         )
@@ -1624,7 +1648,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
         super()._refresh_for_new_column(column)
         self.element._refresh_for_new_column(column)
 
-    def _populate_column_collection(self):
+    def _populate_column_collection(self) -> None:
         self.element._generate_fromclause_column_proxies(self)
 
     @util.ro_non_memoized_property
@@ -1636,11 +1660,11 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
         return name
 
     @util.ro_non_memoized_property
-    def implicit_returning(self):
+    def implicit_returning(self) -> bool:
         return self.element.implicit_returning  # type: ignore
 
     @property
-    def original(self):
+    def original(self) -> ReturnsRows:
         """Legacy for dialects that are referring to Alias.original."""
         return self.element
 
@@ -1747,11 +1771,12 @@ class TableValuedAlias(LateralFromClause, Alias):
 
     def _init(
         self,
-        selectable,
-        name=None,
-        table_value_type=None,
-        joins_implicitly=False,
-    ):
+        selectable: Any,
+        *,
+        name: Optional[str] = None,
+        table_value_type: Optional[TableValueType] = None,
+        joins_implicitly: bool = False,
+    ) -> None:
         super()._init(selectable, name=name)
 
         self.joins_implicitly = joins_implicitly
@@ -1762,7 +1787,7 @@ class TableValuedAlias(LateralFromClause, Alias):
         )
 
     @HasMemoized.memoized_attribute
-    def column(self):
+    def column(self) -> TableValuedColumn[Any]:
         """Return a column expression representing this
         :class:`_sql.TableValuedAlias`.
 
@@ -1793,7 +1818,7 @@ class TableValuedAlias(LateralFromClause, Alias):
 
         """
 
-        tva = TableValuedAlias._construct(
+        tva: TableValuedAlias = TableValuedAlias._construct(
             self,
             name=name,
             table_value_type=self._tableval_type,
@@ -1819,7 +1844,11 @@ class TableValuedAlias(LateralFromClause, Alias):
         tva._is_lateral = True
         return tva
 
-    def render_derived(self, name=None, with_types=False):
+    def render_derived(
+        self,
+        name: Optional[str] = None,
+        with_types: bool = False,
+    ) -> TableValuedAlias:
         """Apply "render derived" to this :class:`_sql.TableValuedAlias`.
 
         This has the effect of the individual column names listed out
@@ -1867,7 +1896,7 @@ class TableValuedAlias(LateralFromClause, Alias):
 
         # construct against original to prevent memory growth
         # for repeated generations
-        new_alias = TableValuedAlias._construct(
+        new_alias: TableValuedAlias = TableValuedAlias._construct(
             self.element,
             name=name,
             table_value_type=self._tableval_type,
@@ -1952,16 +1981,24 @@ class TableSample(FromClauseAlias):
         )
 
     @util.preload_module("sqlalchemy.sql.functions")
-    def _init(self, selectable, sampling, name=None, seed=None):
+    def _init(  # type: ignore[override]
+        self,
+        selectable: Any,
+        *,
+        name: Optional[str] = None,
+        sampling: Union[float, Function[Any]],
+        seed: Optional[roles.ExpressionElementRole[Any]] = None,
+    ) -> None:
+        assert sampling is not None
         functions = util.preloaded.sql_functions
         if not isinstance(sampling, functions.Function):
             sampling = functions.func.system(sampling)
 
-        self.sampling = sampling
+        self.sampling: Function[Any] = sampling
         self.seed = seed
         super()._init(selectable, name=name)
 
-    def _get_method(self):
+    def _get_method(self) -> Function[Any]:
         return self.sampling
 
 
@@ -2026,15 +2063,16 @@ class CTE(
 
     def _init(
         self,
-        selectable,
-        name=None,
-        recursive=False,
-        nesting=False,
-        _cte_alias=None,
-        _restates=None,
-        _prefixes=None,
-        _suffixes=None,
-    ):
+        selectable: Select[Any],
+        *,
+        name: Optional[str] = None,
+        recursive: bool = False,
+        nesting: bool = False,
+        _cte_alias: Optional[CTE] = None,
+        _restates: Optional[CTE] = None,
+        _prefixes: Optional[Tuple[()]] = None,
+        _suffixes: Optional[Tuple[()]] = None,
+    ) -> None:
         self.recursive = recursive
         self.nesting = nesting
         self._cte_alias = _cte_alias
@@ -2046,7 +2084,7 @@ class CTE(
             self._suffixes = _suffixes
         super()._init(selectable, name=name)
 
-    def _populate_column_collection(self):
+    def _populate_column_collection(self) -> None:
         if self._cte_alias is not None:
             self._cte_alias._generate_fromclause_column_proxies(self)
         else:
@@ -2135,7 +2173,7 @@ class CTE(
             _suffixes=self._suffixes,
         )
 
-    def _get_reference_cte(self):
+    def _get_reference_cte(self) -> CTE:
         """
         A recursive CTE is updated to attach the recursive part.
         Updated CTEs should still refer to the original CTE.
@@ -2816,7 +2854,7 @@ class Subquery(AliasedReturnsRows):
         "construct before constructing a subquery object, or with the ORM "
         "use the :meth:`_query.Query.scalar_subquery` method.",
     )
-    def as_scalar(self):
+    def as_scalar(self) -> ScalarSelect[Any]:
         return self.element.set_label_style(LABEL_STYLE_NONE).scalar_subquery()
 
 
@@ -2832,23 +2870,25 @@ class FromGrouping(GroupedElement, FromClause):
     def __init__(self, element: FromClause):
         self.element = coercions.expect(roles.FromClauseRole, element)
 
-    def _init_collections(self):
+    def _init_collections(self) -> None:
         pass
 
     @util.ro_non_memoized_property
-    def columns(self):
+    def columns(
+        self,
+    ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         return self.element.columns
 
     @util.ro_non_memoized_property
-    def c(self):
+    def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         return self.element.columns
 
     @property
-    def primary_key(self):
+    def primary_key(self) -> Iterable[NamedColumn[Any]]:
         return self.element.primary_key
 
     @property
-    def foreign_keys(self):
+    def foreign_keys(self) -> Iterable[ForeignKey]:
         return self.element.foreign_keys
 
     def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
@@ -2859,7 +2899,7 @@ class FromGrouping(GroupedElement, FromClause):
     ) -> NamedFromGrouping:
         return NamedFromGrouping(self.element.alias(name=name, flat=flat))
 
-    def _anonymous_fromclause(self, **kw):
+    def _anonymous_fromclause(self, **kw: Any) -> FromGrouping:
         return FromGrouping(self.element._anonymous_fromclause(**kw))
 
     @util.ro_non_memoized_property
@@ -2870,10 +2910,10 @@ class FromGrouping(GroupedElement, FromClause):
     def _from_objects(self) -> List[FromClause]:
         return self.element._from_objects
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, FromClause]:
         return {"element": self.element}
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, FromClause]) -> None:
         self.element = state["element"]
 
 
@@ -3074,7 +3114,7 @@ class ForUpdateArg(ClauseElement):
         else:
             return ForUpdateArg(**cast("Dict[str, Any]", with_for_update))
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         return (
             isinstance(other, ForUpdateArg)
             and other.nowait == self.nowait
@@ -3084,10 +3124,10 @@ class ForUpdateArg(ClauseElement):
             and other.of is self.of
         )
 
-    def __ne__(self, other):
+    def __ne__(self, other: Any) -> bool:
         return not self.__eq__(other)
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return id(self)
 
     def __init__(
@@ -3166,7 +3206,7 @@ class Values(roles.InElementRole, Generative, LateralFromClause):
         self.named_with_column = not self._unnamed
 
     @property
-    def _column_types(self):
+    def _column_types(self) -> List[TypeEngine[Any]]:
         return [col.type for col in self._column_args]
 
     @_generative
@@ -3293,10 +3333,10 @@ class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]):
         self.literal_binds = literal_binds
 
     @property
-    def _column_types(self):
+    def _column_types(self) -> List[TypeEngine[Any]]:
         return [col.type for col in self._column_args]
 
-    def __clause_element__(self):
+    def __clause_element__(self) -> ScalarValues:
         return self
 
 
@@ -3365,6 +3405,7 @@ class SelectBase(
     def _generate_fromclause_column_proxies(
         self,
         subquery: FromClause,
+        *,
         proxy_compound_columns: Optional[
             Iterable[Sequence[ColumnElement[Any]]]
         ] = None,
@@ -3427,11 +3468,13 @@ class SelectBase(
         "from, use the :attr:`_expression.SelectBase.selected_columns` "
         "attribute.",
     )
-    def c(self):
+    def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         return self._implicit_subquery.columns
 
     @property
-    def columns(self):
+    def columns(
+        self,
+    ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         return self.c
 
     def get_label_style(self) -> SelectLabelStyle:
@@ -3463,11 +3506,11 @@ class SelectBase(
         "first in order to create "
         "a subquery, which then can be selected.",
     )
-    def select(self, *arg, **kw):
+    def select(self, *arg: Any, **kw: Any) -> Select[Any]:
         return self._implicit_subquery.select(*arg, **kw)
 
     @HasMemoized.memoized_attribute
-    def _implicit_subquery(self):
+    def _implicit_subquery(self) -> Subquery:
         return self.subquery()
 
     def _scalar_type(self) -> TypeEngine[Any]:
@@ -3480,7 +3523,7 @@ class SelectBase(
         "removed in a future release.  Please refer to "
         ":meth:`_expression.SelectBase.scalar_subquery`.",
     )
-    def as_scalar(self):
+    def as_scalar(self) -> ScalarSelect[Any]:
         return self.scalar_subquery()
 
     def exists(self) -> Exists:
@@ -3595,9 +3638,11 @@ class SelectBase(
 
         """
 
-        return Subquery._construct(self._ensure_disambiguated_names(), name)
+        return Subquery._construct(
+            self._ensure_disambiguated_names(), name=name
+        )
 
-    def _ensure_disambiguated_names(self):
+    def _ensure_disambiguated_names(self: SelfSelectBase) -> SelfSelectBase:
         """Ensure that the names generated by this selectbase will be
         disambiguated in some way, if possible.
 
@@ -3625,7 +3670,10 @@ class SelectBase(
         return self.subquery(name=name)
 
 
-class SelectStatementGrouping(GroupedElement, SelectBase):
+_SB = TypeVar("_SB", bound=SelectBase)
+
+
+class SelectStatementGrouping(GroupedElement, SelectBase, Generic[_SB]):
     """Represent a grouping of a :class:`_expression.SelectBase`.
 
     This differs from :class:`.Subquery` in that we are still
@@ -3641,12 +3689,14 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
 
     _is_select_container = True
 
-    element: SelectBase
+    element: _SB
 
-    def __init__(self, element):
-        self.element = coercions.expect(roles.SelectStatementRole, element)
+    def __init__(self, element: _SB) -> None:
+        self.element = cast(
+            _SB, coercions.expect(roles.SelectStatementRole, element)
+        )
 
-    def _ensure_disambiguated_names(self):
+    def _ensure_disambiguated_names(self) -> SelectStatementGrouping[_SB]:
         new_element = self.element._ensure_disambiguated_names()
         if new_element is not self.element:
             return SelectStatementGrouping(new_element)
@@ -3658,19 +3708,24 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
 
     def set_label_style(
         self, label_style: SelectLabelStyle
-    ) -> SelectStatementGrouping:
+    ) -> SelectStatementGrouping[_SB]:
         return SelectStatementGrouping(
             self.element.set_label_style(label_style)
         )
 
     @property
-    def select_statement(self):
+    def select_statement(self) -> _SB:
         return self.element
 
     def self_group(self: Self, against: Optional[OperatorType] = None) -> Self:
         ...
         return self
 
+    if TYPE_CHECKING:
+
+        def _ungroup(self) -> _SB:
+            ...
+
     # def _generate_columns_plus_names(
     #    self, anon_for_dupe_key: bool
     # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
@@ -3679,6 +3734,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
     def _generate_fromclause_column_proxies(
         self,
         subquery: FromClause,
+        *,
         proxy_compound_columns: Optional[
             Iterable[Sequence[ColumnElement[Any]]]
         ] = None,
@@ -4233,7 +4289,13 @@ class GenerativeSelect(SelectBase, Generative):
 @CompileState.plugin_for("default", "compound_select")
 class CompoundSelectState(CompileState):
     @util.memoized_property
-    def _label_resolve_dict(self):
+    def _label_resolve_dict(
+        self,
+    ) -> Tuple[
+        Dict[str, ColumnElement[Any]],
+        Dict[str, ColumnElement[Any]],
+        Dict[str, ColumnElement[Any]],
+    ]:
         # TODO: this is hacky and slow
         hacky_subquery = self.statement.subquery()
         hacky_subquery.named_with_column = False
@@ -4355,7 +4417,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
                 return True
         return False
 
-    def set_label_style(self, style):
+    def set_label_style(self, style: SelectLabelStyle) -> CompoundSelect:
         if self._label_style is not style:
             self = self._generate()
             select_0 = self.selects[0].set_label_style(style)
@@ -4363,7 +4425,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
 
         return self
 
-    def _ensure_disambiguated_names(self):
+    def _ensure_disambiguated_names(self) -> CompoundSelect:
         new_select = self.selects[0]._ensure_disambiguated_names()
         if new_select is not self.selects[0]:
             self = self._generate()
@@ -4374,6 +4436,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
     def _generate_fromclause_column_proxies(
         self,
         subquery: FromClause,
+        *,
         proxy_compound_columns: Optional[
             Iterable[Sequence[ColumnElement[Any]]]
         ] = None,
@@ -4417,7 +4480,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
             subquery, proxy_compound_columns=extra_col_iterator
         )
 
-    def _refresh_for_new_column(self, column):
+    def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
         super()._refresh_for_new_column(column)
         for select in self.selects:
             select._refresh_for_new_column(column)
@@ -4548,11 +4611,16 @@ class SelectState(util.MemoizedSlots, CompileState):
         pa = prefix_anon_map()
         names = set()
 
-        def go(c, col_name=None):
-            if c._is_text_clause:
+        def go(
+            c: Union[ColumnElement[Any], TextClause],
+            col_name: Optional[str] = None,
+        ) -> Optional[str]:
+            if is_text_clause(c):
                 return None
+            elif TYPE_CHECKING:
+                assert is_column_element(c)
 
-            elif not dedupe:
+            if not dedupe:
                 name = c._proxy_key
                 if name is None:
                     name = "_no_label"
@@ -5153,7 +5221,11 @@ class Select(
 
         return self.where(*criteria)
 
-    def _filter_by_zero(self):
+    def _filter_by_zero(
+        self,
+    ) -> Union[
+        FromClause, _JoinTargetProtocol, ColumnElement[Any], TextClause
+    ]:
         if self._setup_joins:
             meth = SelectState.get_plugin_class(
                 self
@@ -5628,10 +5700,14 @@ class Select(
 
         # 3. clone everything else, making sure we use columns
         # corresponding to the froms we just made.
-        def replace(obj, **kw):
+        def replace(
+            obj: Union[BinaryExpression[Any], ColumnClause[Any]],
+            **kw: Any,
+        ) -> Optional[KeyedColumnElement[ColumnElement[Any]]]:
             if isinstance(obj, ColumnClause) and obj.table in new_froms:
                 newelem = new_froms[obj.table].corresponding_column(obj)
                 return newelem
+            return None
 
         kw["replace"] = replace
 
@@ -6241,7 +6317,7 @@ class Select(
         meth = SelectState.get_plugin_class(self).all_selected_columns
         return list(meth(self))
 
-    def _ensure_disambiguated_names(self):
+    def _ensure_disambiguated_names(self) -> Select[Any]:
         if self._label_style is LABEL_STYLE_NONE:
             self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY)
         return self
@@ -6249,6 +6325,7 @@ class Select(
     def _generate_fromclause_column_proxies(
         self,
         subquery: FromClause,
+        *,
         proxy_compound_columns: Optional[
             Iterable[Sequence[ColumnElement[Any]]]
         ] = None,
@@ -6304,14 +6381,14 @@ class Select(
 
         subquery._columns._populate_separate_keys(prox)
 
-    def _needs_parens_for_grouping(self):
+    def _needs_parens_for_grouping(self) -> bool:
         return self._has_row_limiting_clause or bool(
             self._order_by_clause.clauses
         )
 
     def self_group(
         self: Self, against: Optional[OperatorType] = None
-    ) -> Union[SelectStatementGrouping, Self]:
+    ) -> Union[SelectStatementGrouping[Self], Self]:
         ...
         """Return a 'grouping' construct as per the
         :class:`_expression.ClauseElement` specification.
@@ -6475,22 +6552,22 @@ class ScalarSelect(
 
     element: SelectBase
 
-    def __init__(self, element):
+    def __init__(self, element: SelectBase) -> None:
         self.element = element
         self.type = element._scalar_type()
 
-    def __getattr__(self, attr):
+    def __getattr__(self, attr: str) -> Any:
         return getattr(self.element, attr)
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         return {"element": self.element, "type": self.type}
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         self.element = state["element"]
         self.type = state["type"]
 
     @property
-    def columns(self):
+    def columns(self) -> NoReturn:
         raise exc.InvalidRequestError(
             "Scalar Select expression has no "
             "columns; use this object directly "
@@ -6528,6 +6605,11 @@ class ScalarSelect(
 
         return self
 
+    if TYPE_CHECKING:
+
+        def _ungroup(self) -> Select[Any]:
+            ...
+
     @_generative
     def correlate(
         self: SelfScalarSelect,
@@ -6617,6 +6699,7 @@ class Exists(UnaryExpression[bool]):
     """
 
     inherit_cache = True
+    element: Union[SelectStatementGrouping[Select[Any]], ScalarSelect[Any]]
 
     def __init__(
         self,
@@ -6649,10 +6732,15 @@ class Exists(UnaryExpression[bool]):
     def _from_objects(self) -> List[FromClause]:
         return []
 
-    def _regroup(self, fn):
+    def _regroup(
+        self, fn: Callable[[Select[Any]], Select[Any]]
+    ) -> SelectStatementGrouping[Select[Any]]:
         element = self.element._ungroup()
-        element = fn(element)
-        return element.self_group(against=operators.exists)
+        new_element = fn(element)
+
+        return_value = new_element.self_group(against=operators.exists)
+        assert isinstance(return_value, SelectStatementGrouping)
+        return return_value
 
     def select(self) -> Select[Any]:
         r"""Return a SELECT of this :class:`_expression.Exists`.
@@ -6792,7 +6880,12 @@ class TextualSelect(SelectBase, Executable, Generative):
     is_text = True
     is_select = True
 
-    def __init__(self, text, columns, positional=False):
+    def __init__(
+        self,
+        text: TextClause,
+        columns: List[ColumnClause[Any]],
+        positional: bool = False,
+    ) -> None:
         self.element = text
         # convert for ORM attributes->columns, etc
         self.column_args = [
@@ -6832,10 +6925,10 @@ class TextualSelect(SelectBase, Executable, Generative):
     def _all_selected_columns(self) -> _SelectIterable:
         return self.column_args
 
-    def set_label_style(self, style):
+    def set_label_style(self, style: SelectLabelStyle) -> TextualSelect:
         return self
 
-    def _ensure_disambiguated_names(self):
+    def _ensure_disambiguated_names(self) -> TextualSelect:
         return self
 
     @_generative
@@ -6848,8 +6941,16 @@ class TextualSelect(SelectBase, Executable, Generative):
         return self
 
     def _generate_fromclause_column_proxies(
-        self, fromclause, proxy_compound_columns=None
-    ):
+        self,
+        fromclause: FromClause,
+        *,
+        proxy_compound_columns: Optional[
+            Iterable[Sequence[ColumnElement[Any]]]
+        ] = None,
+    ) -> None:
+        if TYPE_CHECKING:
+            assert isinstance(fromclause, Subquery)
+
         if proxy_compound_columns:
             fromclause._columns._populate_separate_keys(
                 c._make_proxy(fromclause, compound_select_cols=extra_cols)
@@ -6862,7 +6963,7 @@ class TextualSelect(SelectBase, Executable, Generative):
                 c._make_proxy(fromclause) for c in self.column_args
             )
 
-    def _scalar_type(self):
+    def _scalar_type(self) -> Union[TypeEngine[Any], Any]:
         return self.column_args[0].type
 
 
@@ -6871,7 +6972,7 @@ TextAsFrom = TextualSelect
 
 
 class AnnotatedFromClause(Annotated):
-    def _copy_internals(self, **kw):
+    def _copy_internals(self, **kw: Any) -> None:
         super()._copy_internals(**kw)
         if kw.get("ind_cols_on_fromclause", False):
             ee = self._Annotated__element  # type: ignore