From 73389e94aefd4604c9f1b1ba744e0e74c99d8371 Mon Sep 17 00:00:00 2001 From: Dzmitar <17720985+dzmitar@users.noreply.github.com> Date: Thu, 12 Jan 2023 10:47:24 -0500 Subject: [PATCH] Type annotations for sqlalchemy.sql.selectable Co-authored-by: Mike Bayer Closes: #9028 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9028 Pull-request-sha: e2f8ddeac0b08feaad917285e988acf1e9465a26 Change-Id: I5caad31bfeeed2d224657a55f067ba1d86b8733f --- lib/sqlalchemy/sql/_typing.py | 2 + lib/sqlalchemy/sql/base.py | 9 +- lib/sqlalchemy/sql/coercions.py | 12 +- lib/sqlalchemy/sql/elements.py | 8 +- lib/sqlalchemy/sql/functions.py | 2 +- lib/sqlalchemy/sql/roles.py | 2 +- lib/sqlalchemy/sql/selectable.py | 297 +++++++++++++++++++++---------- 7 files changed, 224 insertions(+), 108 deletions(-) diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 78e196efc2..a120629caa 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -70,6 +70,8 @@ _T = TypeVar("_T", bound=Any) _CE = TypeVar("_CE", bound="ColumnElement[Any]") +_CLE = TypeVar("_CLE", bound="ClauseElement") + class _HasClauseElement(Protocol): """indicates a class that has a __clause_element__() method""" diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 25e214bd39..96ebc78249 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -66,6 +66,7 @@ if TYPE_CHECKING: from . import type_api from ._orm_types import DMLStrategyArgument from ._orm_types import SynchronizeSessionArgument + from ._typing import _CLE from .elements import BindParameter from .elements import ClauseList from .elements import ColumnClause # noqa @@ -282,7 +283,9 @@ def _clone(element, **kw): return element._clone(**kw) -def _expand_cloned(elements): +def _expand_cloned( + elements: Iterable[_CLE], +) -> Iterable[_CLE]: """expand the given set of ClauseElements to be the set of all 'cloned' predecessors. @@ -291,7 +294,7 @@ def _expand_cloned(elements): return itertools.chain(*[x._cloned_set for x in elements]) -def _cloned_intersection(a, b): +def _cloned_intersection(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]: """return the intersection of sets a and b, counting any overlap between 'cloned' predecessors. @@ -302,7 +305,7 @@ def _cloned_intersection(a, b): return {elem for elem in a if all_overlap.intersection(elem._cloned_set)} -def _cloned_difference(a, b): +def _cloned_difference(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]: all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) return { elem for elem in a if not all_overlap.intersection(elem._cloned_set) diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 30e8bd7740..9fe65b0cd6 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -14,6 +14,7 @@ import re import typing from typing import Any from typing import Callable +from typing import cast from typing import Dict from typing import Iterable from typing import Iterator @@ -21,6 +22,7 @@ from typing import List from typing import NoReturn from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING @@ -140,7 +142,11 @@ def _document_text_coercion( ) -def _expression_collection_was_a_list(attrname, fnname, args): +def _expression_collection_was_a_list( + attrname: str, + fnname: str, + args: Union[Sequence[_T], Sequence[Sequence[_T]]], +) -> Sequence[_T]: if args and isinstance(args[0], (list, set, dict)) and len(args) == 1: if isinstance(args[0], list): raise exc.ArgumentError( @@ -149,9 +155,9 @@ def _expression_collection_was_a_list(attrname, fnname, args): "of items, is now passed as a series of positional " "elements, rather than as a list. " ) - return args[0] + return cast("Sequence[_T]", args[0]) - return args + return cast("Sequence[_T]", args) @overload diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 748e9504b3..f9c7cac23e 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -582,7 +582,9 @@ class ClauseElement( """ return traversals.compare(self, other, **kw) - def self_group(self, against: Optional[OperatorType] = None) -> Any: + def self_group( + self, against: Optional[OperatorType] = None + ) -> ClauseElement: """Apply a 'grouping' to this :class:`_expression.ClauseElement`. This method is overridden by subclasses to return a "grouping" @@ -609,7 +611,7 @@ class ClauseElement( """ return self - def _ungroup(self): + def _ungroup(self) -> ClauseElement: """Return this :class:`_expression.ClauseElement` without any groupings. """ @@ -3452,6 +3454,8 @@ class UnaryExpression(ColumnElement[_T]): ("modifier", InternalTraversal.dp_operator), ] + element: ClauseElement + def __init__( self, element: ColumnElement[Any], diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 902811037e..ca30ab5ead 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -628,7 +628,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return TableValuedAlias._construct( self, - name, + name=name, table_value_type=self.type, joins_implicitly=joins_implicitly, ) diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 35d1708e24..f8aac70b99 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -225,7 +225,7 @@ class AnonymizedFromClauseRole(StrictFromClauseRole): if TYPE_CHECKING: def _anonymous_fromclause( - self, name: Optional[str] = None, flat: bool = False + self, *, name: Optional[str] = None, flat: bool = False ) -> FromClause: ... diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 54230d58a6..511c485514 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-defs, allow-untyped-calls """The :class:`_expression.FromClause` class of SQL expression elements, representing @@ -54,6 +53,7 @@ from ._typing import is_column_element from ._typing import is_select_statement from ._typing import is_subquery from ._typing import is_table +from ._typing import is_text_clause from .annotation import Annotated from .annotation import SupportsCloneAnnotations from .base import _clone @@ -131,6 +131,7 @@ if TYPE_CHECKING: from .compiler import SQLCompiler from .dml import Delete from .dml import Update + from .elements import BinaryExpression from .elements import KeyedColumnElement from .elements import Label from .elements import NamedColumn @@ -138,6 +139,7 @@ if TYPE_CHECKING: from .functions import Function from .schema import ForeignKey from .schema import ForeignKeyConstraint + from .sqltypes import TableValueType from .type_api import TypeEngine from .visitors import _CloneCallableType @@ -153,6 +155,10 @@ class _JoinTargetProtocol(Protocol): def _from_objects(self) -> List[FromClause]: ... + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: + ... + _JoinTargetElement = Union["FromClause", _JoinTargetProtocol] _OnClauseElement = Union["ColumnElement[bool]", _JoinTargetProtocol] @@ -295,7 +301,7 @@ class Selectable(ReturnsRows): :ref:`tutorial_lateral_correlation` - overview of usage. """ - return Lateral._construct(self, name) + return Lateral._construct(self, name=name) @util.deprecated( "1.4", @@ -757,7 +763,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ - return Alias._construct(self, name) + return Alias._construct(self, name=name) def tablesample( self, @@ -778,7 +784,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): :func:`_expression.tablesample` - usage guidelines and parameters """ - return TableSample._construct(self, sampling, name, seed) + return TableSample._construct( + self, sampling=sampling, name=name, seed=seed + ) def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: """Return ``True`` if this :class:`_expression.FromClause` is @@ -991,8 +999,8 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self._reset_column_collection() def _anonymous_fromclause( - self, name: Optional[str] = None, flat: bool = False - ) -> NamedFromClause: + self, *, name: Optional[str] = None, flat: bool = False + ) -> FromClause: return self.alias(name=name) if TYPE_CHECKING: @@ -1252,7 +1260,7 @@ class Join(roles.DMLTableRole, FromClause): return FromGrouping(self) @util.preload_module("sqlalchemy.sql.util") - def _populate_column_collection(self): + def _populate_column_collection(self) -> None: sqlutil = util.preloaded.sql_util columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [ c for c in self.right.c @@ -1291,10 +1299,14 @@ class Join(roles.DMLTableRole, FromClause): # set up a special replace function that will replace for # ColumnClause with parent table referring to those # replaced FromClause objects - def replace(obj, **kw): + def replace( + obj: Union[BinaryExpression[Any], ColumnClause[Any]], + **kw: Any, + ) -> Optional[KeyedColumnElement[ColumnElement[Any]]]: if isinstance(obj, ColumnClause) and obj.table in new_froms: newelem = new_froms[obj.table].corresponding_column(obj) return newelem + return None kw["replace"] = replace @@ -1311,7 +1323,9 @@ class Join(roles.DMLTableRole, FromClause): self.right._refresh_for_new_column(column) def _match_primaries( - self, left: FromClause, right: FromClause + self, + left: FromClause, + right: FromClause, ) -> ColumnElement[bool]: if isinstance(left, Join): left_right = left.right @@ -1460,8 +1474,12 @@ class Join(roles.DMLTableRole, FromClause): @classmethod def _joincond_trim_constraints( - cls, a, b, constraints, consider_as_foreign_keys - ): + cls, + a: FromClause, + b: FromClause, + constraints: Dict[Any, Any], + consider_as_foreign_keys: Optional[Any], + ) -> None: # more than one constraint matched. narrow down the list # to include just those FKCs that match exactly to # "consider_as_foreign_keys". @@ -1508,7 +1526,9 @@ class Join(roles.DMLTableRole, FromClause): return Select(self.left, self.right).select_from(self) @util.preload_module("sqlalchemy.sql.util") - def _anonymous_fromclause(self, name=None, flat=False): + def _anonymous_fromclause( + self, name: Optional[str] = None, flat: bool = False + ) -> TODO_Any: sqlutil = util.preloaded.sql_util if flat: if name is not None: @@ -1548,7 +1568,7 @@ class Join(roles.DMLTableRole, FromClause): class NoInit: - def __init__(self, *arg, **kw): + def __init__(self, *arg: Any, **kw: Any): raise NotImplementedError( "The %s class is not intended to be constructed " "directly. Please use the %s() standalone " @@ -1597,13 +1617,17 @@ class AliasedReturnsRows(NoInit, NamedFromClause): @classmethod def _construct( - cls: Type[_SelfAliasedReturnsRows], *arg: Any, **kw: Any + cls: Type[_SelfAliasedReturnsRows], + selectable: Any, + *, + name: Optional[str] = None, + **kw: Any, ) -> _SelfAliasedReturnsRows: obj = cls.__new__(cls) - obj._init(*arg, **kw) + obj._init(selectable, name=name, **kw) return obj - def _init(self, selectable, name=None): + def _init(self, selectable: Any, *, name: Optional[str] = None) -> None: self.element = coercions.expect( roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self ) @@ -1624,7 +1648,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): super()._refresh_for_new_column(column) self.element._refresh_for_new_column(column) - def _populate_column_collection(self): + def _populate_column_collection(self) -> None: self.element._generate_fromclause_column_proxies(self) @util.ro_non_memoized_property @@ -1636,11 +1660,11 @@ class AliasedReturnsRows(NoInit, NamedFromClause): return name @util.ro_non_memoized_property - def implicit_returning(self): + def implicit_returning(self) -> bool: return self.element.implicit_returning # type: ignore @property - def original(self): + def original(self) -> ReturnsRows: """Legacy for dialects that are referring to Alias.original.""" return self.element @@ -1747,11 +1771,12 @@ class TableValuedAlias(LateralFromClause, Alias): def _init( self, - selectable, - name=None, - table_value_type=None, - joins_implicitly=False, - ): + selectable: Any, + *, + name: Optional[str] = None, + table_value_type: Optional[TableValueType] = None, + joins_implicitly: bool = False, + ) -> None: super()._init(selectable, name=name) self.joins_implicitly = joins_implicitly @@ -1762,7 +1787,7 @@ class TableValuedAlias(LateralFromClause, Alias): ) @HasMemoized.memoized_attribute - def column(self): + def column(self) -> TableValuedColumn[Any]: """Return a column expression representing this :class:`_sql.TableValuedAlias`. @@ -1793,7 +1818,7 @@ class TableValuedAlias(LateralFromClause, Alias): """ - tva = TableValuedAlias._construct( + tva: TableValuedAlias = TableValuedAlias._construct( self, name=name, table_value_type=self._tableval_type, @@ -1819,7 +1844,11 @@ class TableValuedAlias(LateralFromClause, Alias): tva._is_lateral = True return tva - def render_derived(self, name=None, with_types=False): + def render_derived( + self, + name: Optional[str] = None, + with_types: bool = False, + ) -> TableValuedAlias: """Apply "render derived" to this :class:`_sql.TableValuedAlias`. This has the effect of the individual column names listed out @@ -1867,7 +1896,7 @@ class TableValuedAlias(LateralFromClause, Alias): # construct against original to prevent memory growth # for repeated generations - new_alias = TableValuedAlias._construct( + new_alias: TableValuedAlias = TableValuedAlias._construct( self.element, name=name, table_value_type=self._tableval_type, @@ -1952,16 +1981,24 @@ class TableSample(FromClauseAlias): ) @util.preload_module("sqlalchemy.sql.functions") - def _init(self, selectable, sampling, name=None, seed=None): + def _init( # type: ignore[override] + self, + selectable: Any, + *, + name: Optional[str] = None, + sampling: Union[float, Function[Any]], + seed: Optional[roles.ExpressionElementRole[Any]] = None, + ) -> None: + assert sampling is not None functions = util.preloaded.sql_functions if not isinstance(sampling, functions.Function): sampling = functions.func.system(sampling) - self.sampling = sampling + self.sampling: Function[Any] = sampling self.seed = seed super()._init(selectable, name=name) - def _get_method(self): + def _get_method(self) -> Function[Any]: return self.sampling @@ -2026,15 +2063,16 @@ class CTE( def _init( self, - selectable, - name=None, - recursive=False, - nesting=False, - _cte_alias=None, - _restates=None, - _prefixes=None, - _suffixes=None, - ): + selectable: Select[Any], + *, + name: Optional[str] = None, + recursive: bool = False, + nesting: bool = False, + _cte_alias: Optional[CTE] = None, + _restates: Optional[CTE] = None, + _prefixes: Optional[Tuple[()]] = None, + _suffixes: Optional[Tuple[()]] = None, + ) -> None: self.recursive = recursive self.nesting = nesting self._cte_alias = _cte_alias @@ -2046,7 +2084,7 @@ class CTE( self._suffixes = _suffixes super()._init(selectable, name=name) - def _populate_column_collection(self): + def _populate_column_collection(self) -> None: if self._cte_alias is not None: self._cte_alias._generate_fromclause_column_proxies(self) else: @@ -2135,7 +2173,7 @@ class CTE( _suffixes=self._suffixes, ) - def _get_reference_cte(self): + def _get_reference_cte(self) -> CTE: """ A recursive CTE is updated to attach the recursive part. Updated CTEs should still refer to the original CTE. @@ -2816,7 +2854,7 @@ class Subquery(AliasedReturnsRows): "construct before constructing a subquery object, or with the ORM " "use the :meth:`_query.Query.scalar_subquery` method.", ) - def as_scalar(self): + def as_scalar(self) -> ScalarSelect[Any]: return self.element.set_label_style(LABEL_STYLE_NONE).scalar_subquery() @@ -2832,23 +2870,25 @@ class FromGrouping(GroupedElement, FromClause): def __init__(self, element: FromClause): self.element = coercions.expect(roles.FromClauseRole, element) - def _init_collections(self): + def _init_collections(self) -> None: pass @util.ro_non_memoized_property - def columns(self): + def columns( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: return self.element.columns @util.ro_non_memoized_property - def c(self): + def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: return self.element.columns @property - def primary_key(self): + def primary_key(self) -> Iterable[NamedColumn[Any]]: return self.element.primary_key @property - def foreign_keys(self): + def foreign_keys(self) -> Iterable[ForeignKey]: return self.element.foreign_keys def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: @@ -2859,7 +2899,7 @@ class FromGrouping(GroupedElement, FromClause): ) -> NamedFromGrouping: return NamedFromGrouping(self.element.alias(name=name, flat=flat)) - def _anonymous_fromclause(self, **kw): + def _anonymous_fromclause(self, **kw: Any) -> FromGrouping: return FromGrouping(self.element._anonymous_fromclause(**kw)) @util.ro_non_memoized_property @@ -2870,10 +2910,10 @@ class FromGrouping(GroupedElement, FromClause): def _from_objects(self) -> List[FromClause]: return self.element._from_objects - def __getstate__(self): + def __getstate__(self) -> Dict[str, FromClause]: return {"element": self.element} - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, FromClause]) -> None: self.element = state["element"] @@ -3074,7 +3114,7 @@ class ForUpdateArg(ClauseElement): else: return ForUpdateArg(**cast("Dict[str, Any]", with_for_update)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return ( isinstance(other, ForUpdateArg) and other.nowait == self.nowait @@ -3084,10 +3124,10 @@ class ForUpdateArg(ClauseElement): and other.of is self.of ) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return id(self) def __init__( @@ -3166,7 +3206,7 @@ class Values(roles.InElementRole, Generative, LateralFromClause): self.named_with_column = not self._unnamed @property - def _column_types(self): + def _column_types(self) -> List[TypeEngine[Any]]: return [col.type for col in self._column_args] @_generative @@ -3293,10 +3333,10 @@ class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]): self.literal_binds = literal_binds @property - def _column_types(self): + def _column_types(self) -> List[TypeEngine[Any]]: return [col.type for col in self._column_args] - def __clause_element__(self): + def __clause_element__(self) -> ScalarValues: return self @@ -3365,6 +3405,7 @@ class SelectBase( def _generate_fromclause_column_proxies( self, subquery: FromClause, + *, proxy_compound_columns: Optional[ Iterable[Sequence[ColumnElement[Any]]] ] = None, @@ -3427,11 +3468,13 @@ class SelectBase( "from, use the :attr:`_expression.SelectBase.selected_columns` " "attribute.", ) - def c(self): + def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: return self._implicit_subquery.columns @property - def columns(self): + def columns( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: return self.c def get_label_style(self) -> SelectLabelStyle: @@ -3463,11 +3506,11 @@ class SelectBase( "first in order to create " "a subquery, which then can be selected.", ) - def select(self, *arg, **kw): + def select(self, *arg: Any, **kw: Any) -> Select[Any]: return self._implicit_subquery.select(*arg, **kw) @HasMemoized.memoized_attribute - def _implicit_subquery(self): + def _implicit_subquery(self) -> Subquery: return self.subquery() def _scalar_type(self) -> TypeEngine[Any]: @@ -3480,7 +3523,7 @@ class SelectBase( "removed in a future release. Please refer to " ":meth:`_expression.SelectBase.scalar_subquery`.", ) - def as_scalar(self): + def as_scalar(self) -> ScalarSelect[Any]: return self.scalar_subquery() def exists(self) -> Exists: @@ -3595,9 +3638,11 @@ class SelectBase( """ - return Subquery._construct(self._ensure_disambiguated_names(), name) + return Subquery._construct( + self._ensure_disambiguated_names(), name=name + ) - def _ensure_disambiguated_names(self): + def _ensure_disambiguated_names(self: SelfSelectBase) -> SelfSelectBase: """Ensure that the names generated by this selectbase will be disambiguated in some way, if possible. @@ -3625,7 +3670,10 @@ class SelectBase( return self.subquery(name=name) -class SelectStatementGrouping(GroupedElement, SelectBase): +_SB = TypeVar("_SB", bound=SelectBase) + + +class SelectStatementGrouping(GroupedElement, SelectBase, Generic[_SB]): """Represent a grouping of a :class:`_expression.SelectBase`. This differs from :class:`.Subquery` in that we are still @@ -3641,12 +3689,14 @@ class SelectStatementGrouping(GroupedElement, SelectBase): _is_select_container = True - element: SelectBase + element: _SB - def __init__(self, element): - self.element = coercions.expect(roles.SelectStatementRole, element) + def __init__(self, element: _SB) -> None: + self.element = cast( + _SB, coercions.expect(roles.SelectStatementRole, element) + ) - def _ensure_disambiguated_names(self): + def _ensure_disambiguated_names(self) -> SelectStatementGrouping[_SB]: new_element = self.element._ensure_disambiguated_names() if new_element is not self.element: return SelectStatementGrouping(new_element) @@ -3658,19 +3708,24 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def set_label_style( self, label_style: SelectLabelStyle - ) -> SelectStatementGrouping: + ) -> SelectStatementGrouping[_SB]: return SelectStatementGrouping( self.element.set_label_style(label_style) ) @property - def select_statement(self): + def select_statement(self) -> _SB: return self.element def self_group(self: Self, against: Optional[OperatorType] = None) -> Self: ... return self + if TYPE_CHECKING: + + def _ungroup(self) -> _SB: + ... + # def _generate_columns_plus_names( # self, anon_for_dupe_key: bool # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: @@ -3679,6 +3734,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def _generate_fromclause_column_proxies( self, subquery: FromClause, + *, proxy_compound_columns: Optional[ Iterable[Sequence[ColumnElement[Any]]] ] = None, @@ -4233,7 +4289,13 @@ class GenerativeSelect(SelectBase, Generative): @CompileState.plugin_for("default", "compound_select") class CompoundSelectState(CompileState): @util.memoized_property - def _label_resolve_dict(self): + def _label_resolve_dict( + self, + ) -> Tuple[ + Dict[str, ColumnElement[Any]], + Dict[str, ColumnElement[Any]], + Dict[str, ColumnElement[Any]], + ]: # TODO: this is hacky and slow hacky_subquery = self.statement.subquery() hacky_subquery.named_with_column = False @@ -4355,7 +4417,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): return True return False - def set_label_style(self, style): + def set_label_style(self, style: SelectLabelStyle) -> CompoundSelect: if self._label_style is not style: self = self._generate() select_0 = self.selects[0].set_label_style(style) @@ -4363,7 +4425,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): return self - def _ensure_disambiguated_names(self): + def _ensure_disambiguated_names(self) -> CompoundSelect: new_select = self.selects[0]._ensure_disambiguated_names() if new_select is not self.selects[0]: self = self._generate() @@ -4374,6 +4436,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): def _generate_fromclause_column_proxies( self, subquery: FromClause, + *, proxy_compound_columns: Optional[ Iterable[Sequence[ColumnElement[Any]]] ] = None, @@ -4417,7 +4480,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): subquery, proxy_compound_columns=extra_col_iterator ) - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: super()._refresh_for_new_column(column) for select in self.selects: select._refresh_for_new_column(column) @@ -4548,11 +4611,16 @@ class SelectState(util.MemoizedSlots, CompileState): pa = prefix_anon_map() names = set() - def go(c, col_name=None): - if c._is_text_clause: + def go( + c: Union[ColumnElement[Any], TextClause], + col_name: Optional[str] = None, + ) -> Optional[str]: + if is_text_clause(c): return None + elif TYPE_CHECKING: + assert is_column_element(c) - elif not dedupe: + if not dedupe: name = c._proxy_key if name is None: name = "_no_label" @@ -5153,7 +5221,11 @@ class Select( return self.where(*criteria) - def _filter_by_zero(self): + def _filter_by_zero( + self, + ) -> Union[ + FromClause, _JoinTargetProtocol, ColumnElement[Any], TextClause + ]: if self._setup_joins: meth = SelectState.get_plugin_class( self @@ -5628,10 +5700,14 @@ class Select( # 3. clone everything else, making sure we use columns # corresponding to the froms we just made. - def replace(obj, **kw): + def replace( + obj: Union[BinaryExpression[Any], ColumnClause[Any]], + **kw: Any, + ) -> Optional[KeyedColumnElement[ColumnElement[Any]]]: if isinstance(obj, ColumnClause) and obj.table in new_froms: newelem = new_froms[obj.table].corresponding_column(obj) return newelem + return None kw["replace"] = replace @@ -6241,7 +6317,7 @@ class Select( meth = SelectState.get_plugin_class(self).all_selected_columns return list(meth(self)) - def _ensure_disambiguated_names(self): + def _ensure_disambiguated_names(self) -> Select[Any]: if self._label_style is LABEL_STYLE_NONE: self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY) return self @@ -6249,6 +6325,7 @@ class Select( def _generate_fromclause_column_proxies( self, subquery: FromClause, + *, proxy_compound_columns: Optional[ Iterable[Sequence[ColumnElement[Any]]] ] = None, @@ -6304,14 +6381,14 @@ class Select( subquery._columns._populate_separate_keys(prox) - def _needs_parens_for_grouping(self): + def _needs_parens_for_grouping(self) -> bool: return self._has_row_limiting_clause or bool( self._order_by_clause.clauses ) def self_group( self: Self, against: Optional[OperatorType] = None - ) -> Union[SelectStatementGrouping, Self]: + ) -> Union[SelectStatementGrouping[Self], Self]: ... """Return a 'grouping' construct as per the :class:`_expression.ClauseElement` specification. @@ -6475,22 +6552,22 @@ class ScalarSelect( element: SelectBase - def __init__(self, element): + def __init__(self, element: SelectBase) -> None: self.element = element self.type = element._scalar_type() - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: return getattr(self.element, attr) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return {"element": self.element, "type": self.type} - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: self.element = state["element"] self.type = state["type"] @property - def columns(self): + def columns(self) -> NoReturn: raise exc.InvalidRequestError( "Scalar Select expression has no " "columns; use this object directly " @@ -6528,6 +6605,11 @@ class ScalarSelect( return self + if TYPE_CHECKING: + + def _ungroup(self) -> Select[Any]: + ... + @_generative def correlate( self: SelfScalarSelect, @@ -6617,6 +6699,7 @@ class Exists(UnaryExpression[bool]): """ inherit_cache = True + element: Union[SelectStatementGrouping[Select[Any]], ScalarSelect[Any]] def __init__( self, @@ -6649,10 +6732,15 @@ class Exists(UnaryExpression[bool]): def _from_objects(self) -> List[FromClause]: return [] - def _regroup(self, fn): + def _regroup( + self, fn: Callable[[Select[Any]], Select[Any]] + ) -> SelectStatementGrouping[Select[Any]]: element = self.element._ungroup() - element = fn(element) - return element.self_group(against=operators.exists) + new_element = fn(element) + + return_value = new_element.self_group(against=operators.exists) + assert isinstance(return_value, SelectStatementGrouping) + return return_value def select(self) -> Select[Any]: r"""Return a SELECT of this :class:`_expression.Exists`. @@ -6792,7 +6880,12 @@ class TextualSelect(SelectBase, Executable, Generative): is_text = True is_select = True - def __init__(self, text, columns, positional=False): + def __init__( + self, + text: TextClause, + columns: List[ColumnClause[Any]], + positional: bool = False, + ) -> None: self.element = text # convert for ORM attributes->columns, etc self.column_args = [ @@ -6832,10 +6925,10 @@ class TextualSelect(SelectBase, Executable, Generative): def _all_selected_columns(self) -> _SelectIterable: return self.column_args - def set_label_style(self, style): + def set_label_style(self, style: SelectLabelStyle) -> TextualSelect: return self - def _ensure_disambiguated_names(self): + def _ensure_disambiguated_names(self) -> TextualSelect: return self @_generative @@ -6848,8 +6941,16 @@ class TextualSelect(SelectBase, Executable, Generative): return self def _generate_fromclause_column_proxies( - self, fromclause, proxy_compound_columns=None - ): + self, + fromclause: FromClause, + *, + proxy_compound_columns: Optional[ + Iterable[Sequence[ColumnElement[Any]]] + ] = None, + ) -> None: + if TYPE_CHECKING: + assert isinstance(fromclause, Subquery) + if proxy_compound_columns: fromclause._columns._populate_separate_keys( c._make_proxy(fromclause, compound_select_cols=extra_cols) @@ -6862,7 +6963,7 @@ class TextualSelect(SelectBase, Executable, Generative): c._make_proxy(fromclause) for c in self.column_args ) - def _scalar_type(self): + def _scalar_type(self) -> Union[TypeEngine[Any], Any]: return self.column_args[0].type @@ -6871,7 +6972,7 @@ TextAsFrom = TextualSelect class AnnotatedFromClause(Annotated): - def _copy_internals(self, **kw): + def _copy_internals(self, **kw: Any) -> None: super()._copy_internals(**kw) if kw.get("ind_cols_on_fromclause", False): ee = self._Annotated__element # type: ignore -- 2.47.2