From b9bd6a40d3e215021cdf11cc124a6cd868d37345 Mon Sep 17 00:00:00 2001 From: Dzmitar <17720985+dzmitar@users.noreply.github.com> Date: Thu, 15 Dec 2022 18:25:02 +0100 Subject: [PATCH] Update type annotations in sqlalchemy.sql.selectable --- lib/sqlalchemy/sql/selectable.py | 177 +++++++++++++++++-------------- 1 file changed, 99 insertions(+), 78 deletions(-) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index e91eb4c8ec..60f9a1e642 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -297,7 +297,7 @@ class Selectable(ReturnsRows): :ref:`tutorial_lateral_correlation` - overview of usage. """ - return Lateral._construct(self, name) + return Lateral._construct(self, name) # type: ignore @util.deprecated( "1.4", @@ -759,7 +759,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ - return Alias._construct(self, name) + return Alias._construct(self, name) # type: ignore def tablesample( self, @@ -780,7 +780,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): :func:`_expression.tablesample` - usage guidelines and parameters """ - return TableSample._construct(self, sampling, name, seed) + return TableSample._construct(self, sampling, name, seed) # type: ignore # noqa: E501 def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: """Return ``True`` if this :class:`_expression.FromClause` is @@ -1218,7 +1218,7 @@ class Join(roles.DMLTableRole, FromClause): ).self_group() if onclause is None: - self.onclause = self._match_primaries(self.left, self.right) + self.onclause = self._match_primaries(self.left, self.right) # type: ignore # noqa: E501 else: # note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba # not merged yet @@ -1300,6 +1300,7 @@ class Join(roles.DMLTableRole, FromClause): if isinstance(obj, ColumnClause) and obj.table in new_froms: newelem = new_froms[obj.table].corresponding_column(obj) return newelem + return None kw["replace"] = replace @@ -1523,7 +1524,7 @@ class Join(roles.DMLTableRole, FromClause): self, name: Optional[str] = None, flat: bool = False, - ): # WIP -> ??? NamedFromClause: # Union[Join, Subquery]: + ) -> TODO_Any: sqlutil = util.preloaded.sql_util if flat: if name is not None: @@ -1760,13 +1761,13 @@ class TableValuedAlias(LateralFromClause, Alias): ("_render_derived_w_types", InternalTraversal.dp_boolean), ] - def _init( + def _init( # type: ignore self, selectable, name=None, table_value_type=None, joins_implicitly=False, - ): + ) -> None: super()._init(selectable, name=name) self.joins_implicitly = joins_implicitly @@ -1777,7 +1778,7 @@ class TableValuedAlias(LateralFromClause, Alias): ) @HasMemoized.memoized_attribute - def column(self): + def column(self) -> TableValuedColumn[Any]: """Return a column expression representing this :class:`_sql.TableValuedAlias`. @@ -1808,7 +1809,7 @@ class TableValuedAlias(LateralFromClause, Alias): """ - tva = TableValuedAlias._construct( + tva: TableValuedAlias = TableValuedAlias._construct( # type: ignore self, name=name, table_value_type=self._tableval_type, @@ -1834,7 +1835,11 @@ class TableValuedAlias(LateralFromClause, Alias): tva._is_lateral = True return tva - def render_derived(self, name=None, with_types=False): + def render_derived( + self, + name: Optional[str] = None, + with_types: bool = False, + ) -> TableValuedAlias: """Apply "render derived" to this :class:`_sql.TableValuedAlias`. This has the effect of the individual column names listed out @@ -1882,7 +1887,7 @@ class TableValuedAlias(LateralFromClause, Alias): # construct against original to prevent memory growth # for repeated generations - new_alias = TableValuedAlias._construct( + new_alias: TableValuedAlias = TableValuedAlias._construct( # type: ignore # noqa: E501 self.element, name=name, table_value_type=self._tableval_type, @@ -1947,7 +1952,7 @@ class TableSample(FromClauseAlias): __visit_name__ = "tablesample" _traverse_internals: _TraverseInternalsType = ( - AliasedReturnsRows._traverse_internals + AliasedReturnsRows._traverse_internals # type: ignore + [ ("sampling", InternalTraversal.dp_clauseelement), ("seed", InternalTraversal.dp_clauseelement), @@ -1967,16 +1972,22 @@ class TableSample(FromClauseAlias): ) @util.preload_module("sqlalchemy.sql.functions") - def _init(self, selectable, sampling, name=None, seed=None): + def _init( # type: ignore + self, + selectable: Any, + sampling: Union[float, Function[Any]], + name: Optional[str] = None, + seed: Optional[roles.ExpressionElementRole[Any]] = None, + ) -> None: functions = util.preloaded.sql_functions if not isinstance(sampling, functions.Function): sampling = functions.func.system(sampling) - self.sampling = sampling + self.sampling: Function[Any] = sampling self.seed = seed super()._init(selectable, name=name) - def _get_method(self): + def _get_method(self) -> Function[Any]: return self.sampling @@ -2009,7 +2020,7 @@ class CTE( __visit_name__ = "cte" _traverse_internals: _TraverseInternalsType = ( - AliasedReturnsRows._traverse_internals + AliasedReturnsRows._traverse_internals # type: ignore + [ ("_cte_alias", InternalTraversal.dp_clauseelement), ("_restates", InternalTraversal.dp_clauseelement), @@ -2041,14 +2052,14 @@ class CTE( def _init( self, - selectable, - name=None, - recursive=False, - nesting=False, - _cte_alias=None, - _restates=None, - _prefixes=None, - _suffixes=None, + selectable: Select[Any], + name: Optional[str] = None, + recursive: bool = False, + nesting: bool = False, + _cte_alias: Optional[CTE] = None, + _restates: Optional[Any] = None, + _prefixes: Optional[Tuple[()]] = None, + _suffixes: Optional[Tuple[()]] = None, ) -> None: self.recursive = recursive self.nesting = nesting @@ -2081,7 +2092,7 @@ class CTE( :func:`_expression.alias` """ - return CTE._construct( + return CTE._construct( # type: ignore self.element, name=name, recursive=self.recursive, @@ -2110,7 +2121,7 @@ class CTE( self.element ), f"CTE element f{self.element} does not support union()" - return CTE._construct( + return CTE._construct( # type: ignore self.element.union(*other), name=self.name, recursive=self.recursive, @@ -2140,7 +2151,7 @@ class CTE( self.element ), f"CTE element f{self.element} does not support union_all()" - return CTE._construct( + return CTE._construct( # type: ignore self.element.union_all(*other), name=self.name, recursive=self.recursive, @@ -2150,7 +2161,7 @@ class CTE( _suffixes=self._suffixes, ) - def _get_reference_cte(self): + def _get_reference_cte(self) -> Optional[Any]: """ A recursive CTE is updated to attach the recursive part. Updated CTEs should still refer to the original CTE. @@ -2769,7 +2780,7 @@ class HasCTE(roles.HasCTERole, SelectsRows): :meth:`_expression.HasCTE.cte`. """ - return CTE._construct( + return CTE._construct( # type: ignore self, name=name, recursive=recursive, nesting=nesting ) @@ -2831,7 +2842,7 @@ class Subquery(AliasedReturnsRows): "construct before constructing a subquery object, or with the ORM " "use the :meth:`_query.Query.scalar_subquery` method.", ) - def as_scalar(self): + def as_scalar(self) -> ScalarSelect[Any]: return self.element.set_label_style(LABEL_STYLE_NONE).scalar_subquery() @@ -2876,7 +2887,7 @@ class FromGrouping(GroupedElement, FromClause): ) -> NamedFromGrouping: return NamedFromGrouping(self.element.alias(name=name, flat=flat)) - def _anonymous_fromclause(self, **kw: Any): + def _anonymous_fromclause(self, **kw: Any): # type: ignore return FromGrouping(self.element._anonymous_fromclause(**kw)) @util.ro_non_memoized_property @@ -2887,10 +2898,10 @@ class FromGrouping(GroupedElement, FromClause): def _from_objects(self) -> List[FromClause]: return self.element._from_objects - def __getstate__(self): + def __getstate__(self) -> Dict[str, FromClause]: return {"element": self.element} - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, FromClause]) -> None: self.element = state["element"] @@ -3091,7 +3102,7 @@ class ForUpdateArg(ClauseElement): else: return ForUpdateArg(**cast("Dict[str, Any]", with_for_update)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return ( isinstance(other, ForUpdateArg) and other.nowait == self.nowait @@ -3101,10 +3112,10 @@ class ForUpdateArg(ClauseElement): and other.of is self.of ) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return id(self) def __init__( @@ -3183,7 +3194,7 @@ class Values(roles.InElementRole, Generative, LateralFromClause): self.named_with_column = not self._unnamed @property - def _column_types(self): + def _column_types(self) -> List[TypeEngine[Any]]: return [col.type for col in self._column_args] @_generative @@ -3310,10 +3321,10 @@ class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]): self.literal_binds = literal_binds @property - def _column_types(self): + def _column_types(self) -> List[TypeEngine[Any]]: return [col.type for col in self._column_args] - def __clause_element__(self): + def __clause_element__(self) -> ScalarValues: return self @@ -3447,7 +3458,7 @@ class SelectBase( return self._implicit_subquery.columns @property - def columns(self): + def columns(self): # type: ignore return self.c def get_label_style(self) -> SelectLabelStyle: @@ -3479,7 +3490,7 @@ class SelectBase( "first in order to create " "a subquery, which then can be selected.", ) - def select(self, *arg, **kw): + def select(self, *arg: Any, **kw: Any) -> Select[Any]: return self._implicit_subquery.select(*arg, **kw) @HasMemoized.memoized_attribute @@ -3496,7 +3507,7 @@ class SelectBase( "removed in a future release. Please refer to " ":meth:`_expression.SelectBase.scalar_subquery`.", ) - def as_scalar(self): + def as_scalar(self) -> ScalarSelect[Any]: return self.scalar_subquery() def exists(self) -> Exists: @@ -3543,7 +3554,7 @@ class SelectBase( if self._label_style is not LABEL_STYLE_NONE: self = self.set_label_style(LABEL_STYLE_NONE) - return ScalarSelect(self) + return ScalarSelect(self) # type: ignore def label(self, name: Optional[str]) -> Label[Any]: """Return a 'scalar' representation of this selectable, embedded as a @@ -3569,7 +3580,7 @@ class SelectBase( :ref:`tutorial_lateral_correlation` - overview of usage. """ - return Lateral._factory(self, name) + return Lateral._factory(self, name) # type: ignore def subquery(self, name: Optional[str] = None) -> Subquery: """Return a subquery of this :class:`_expression.SelectBase`. @@ -3611,9 +3622,9 @@ class SelectBase( """ - return Subquery._construct(self._ensure_disambiguated_names(), name) + return Subquery._construct(self._ensure_disambiguated_names(), name) # type: ignore # noqa: E501 - def _ensure_disambiguated_names(self): + def _ensure_disambiguated_names(self): # type: ignore """Ensure that the names generated by this selectbase will be disambiguated in some way, if possible. @@ -3663,7 +3674,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase): self.element = coercions.expect(roles.SelectStatementRole, element) def _ensure_disambiguated_names(self) -> SelectStatementGrouping: - new_element = self.element._ensure_disambiguated_names() + new_element = self.element._ensure_disambiguated_names() # type: ignore # noqa: E501 if new_element is not self.element: return SelectStatementGrouping(new_element) else: @@ -3676,8 +3687,9 @@ class SelectStatementGrouping(GroupedElement, SelectBase): self, label_style: SelectLabelStyle ) -> SelectStatementGrouping: return SelectStatementGrouping( - self.element.set_label_style(label_style) - ) + self.element.set_label_style(label_style) # type: ignore + ) # Argument 1 to "SelectStatementGrouping" has incompatible + # type "SelectBase"; expected "Select[Any]" [arg-type]mypy(error) @property def select_statement(self) -> SelectBase: @@ -4249,7 +4261,7 @@ class GenerativeSelect(SelectBase, Generative): @CompileState.plugin_for("default", "compound_select") class CompoundSelectState(CompileState): @util.memoized_property - def _label_resolve_dict(self): + def _label_resolve_dict(self): # type: ignore # TODO: this is hacky and slow hacky_subquery = self.statement.subquery() hacky_subquery.named_with_column = False @@ -4363,7 +4375,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): def self_group( self, against: Optional[OperatorType] = None ) -> GroupedElement: - return SelectStatementGrouping(self) + return SelectStatementGrouping(self) # type: ignore def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: for s in self.selects: @@ -4380,7 +4392,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): return self def _ensure_disambiguated_names(self) -> CompoundSelect: - new_select = self.selects[0]._ensure_disambiguated_names() + new_select = self.selects[0]._ensure_disambiguated_names() # type: ignore # noqa: E501 if new_select is not self.selects[0]: self = self._generate() self.selects = [new_select] + self.selects[1:] @@ -4572,26 +4584,26 @@ class SelectState(util.MemoizedSlots, CompileState): return None elif not dedupe: - name = c._proxy_key + name = c._proxy_key # type: ignore if name is None: name = "_no_label" return name - name = c._tq_key_label if table_qualified else c._proxy_key + name = c._tq_key_label if table_qualified else c._proxy_key # type: ignore # noqa: E501 if name is None: name = "_no_label" if name in names: - return c._anon_label(name) % pa + return c._anon_label(name) % pa # type: ignore else: names.add(name) return name elif name in names: return ( - c._anon_tq_key_label % pa + c._anon_tq_key_label % pa # type: ignore if table_qualified - else c._anon_key_label % pa + else c._anon_key_label % pa # type: ignore ) else: names.add(name) @@ -4655,7 +4667,7 @@ class SelectState(util.MemoizedSlots, CompileState): if froms: toremove = set( itertools.chain.from_iterable( - [_expand_cloned(f._hide_froms) for f in froms] + [_expand_cloned(f._hide_froms) for f in froms] # type: ignore # noqa: E501 ) ) if toremove: @@ -4703,8 +4715,8 @@ class SelectState(util.MemoizedSlots, CompileState): f for f in froms if f - not in _cloned_intersection( - _cloned_intersection( + not in _cloned_intersection( # type: ignore + _cloned_intersection( # type: ignore froms, explicit_correlate_froms or () ), to_correlate, @@ -4717,8 +4729,8 @@ class SelectState(util.MemoizedSlots, CompileState): f for f in froms if f - not in _cloned_difference( - _cloned_intersection( + not in _cloned_difference( # type: ignore + _cloned_intersection( # type: ignore froms, explicit_correlate_froms or () ), self.statement._correlate_except, @@ -4735,7 +4747,7 @@ class SelectState(util.MemoizedSlots, CompileState): f for f in froms if f - not in _cloned_intersection(froms, implicit_correlate_froms) + not in _cloned_intersection(froms, implicit_correlate_froms) # type: ignore # noqa: E501 ] if not len(froms): @@ -5172,7 +5184,11 @@ class Select( return self.where(*criteria) - def _filter_by_zero(self): + def _filter_by_zero( + self, + ) -> Union[ + FromClause, _JoinTargetProtocol, ColumnElement[Any], TextClause + ]: if self._setup_joins: meth = SelectState.get_plugin_class( self @@ -5215,7 +5231,7 @@ class Select( from_entity = self._filter_by_zero() clauses = [ - _entity_namespace_key(from_entity, key) == value + _entity_namespace_key(from_entity, key) == value # type: ignore for key, value in kwargs.items() ] return self.filter(*clauses) @@ -5653,6 +5669,7 @@ class Select( if isinstance(obj, ColumnClause) and obj.table in new_froms: newelem = new_froms[obj.table].corresponding_column(obj) return newelem + return None kw["replace"] = replace @@ -5941,8 +5958,8 @@ class Select( _MemoizedSelectEntities._generate_for_statement(self) self._raw_columns = [ - coercions.expect(roles.ColumnsClauseRole, c) - for c in coercions._expression_collection_was_a_list( + coercions.expect(roles.ColumnsClauseRole, c) # type: ignore + for c in coercions._expression_collection_was_a_list( # type: ignore # noqa: E501 "entities", "Select.with_only_columns", entities ) ] @@ -6500,13 +6517,13 @@ class ScalarSelect( self.element = element self.type = element._scalar_type() - def __getattr__(self, attr: str) -> Callable[..., Any]: + def __getattr__(self, attr: str): # type: ignore return getattr(self.element, attr) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return {"element": self.element, "type": self.type} - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: self.element = state["element"] self.type = state["type"] @@ -6670,7 +6687,9 @@ class Exists(UnaryExpression[bool]): def _from_objects(self) -> List[FromClause]: return [] - def _regroup(self, fn: Callable) -> SelectStatementGrouping: + def _regroup( + self, fn: Callable[[Any], ColumnElement[Any]] + ) -> ColumnElement[Any]: element = self.element._ungroup() element = fn(element) return element.self_group(against=operators.exists) @@ -6709,7 +6728,7 @@ class Exists(UnaryExpression[bool]): """ e = self._clone() e.element = self._regroup( - lambda element: element.correlate(*fromclauses) + lambda element: element.correlate(*fromclauses) # type: ignore ) return e @@ -6728,7 +6747,7 @@ class Exists(UnaryExpression[bool]): e = self._clone() e.element = self._regroup( - lambda element: element.correlate_except(*fromclauses) + lambda element: element.correlate_except(*fromclauses) # type: ignore # noqa: E501 ) return e @@ -6746,7 +6765,7 @@ class Exists(UnaryExpression[bool]): """ e = self._clone() - e.element = self._regroup(lambda element: element.select_from(*froms)) + e.element = self._regroup(lambda element: element.select_from(*froms)) # type: ignore # noqa: E501 return e def where( @@ -6764,7 +6783,7 @@ class Exists(UnaryExpression[bool]): """ e = self._clone() - e.element = self._regroup(lambda element: element.where(*clause)) + e.element = self._regroup(lambda element: element.where(*clause)) # type: ignore # noqa: E501 return e @@ -6873,8 +6892,10 @@ class TextualSelect(SelectBase, Executable, Generative): self.element = self.element.bindparams(*binds, **bind_as_values) return self - def _generate_fromclause_column_proxies( - self, fromclause: Subquery, proxy_compound_columns: zip = None + def _generate_fromclause_column_proxies( # type: ignore + self, + fromclause: Subquery, + proxy_compound_columns: Optional[zip[Any]] = None, ) -> None: if proxy_compound_columns: fromclause._columns._populate_separate_keys( @@ -6888,7 +6909,7 @@ class TextualSelect(SelectBase, Executable, Generative): c._make_proxy(fromclause) for c in self.column_args ) - def _scalar_type(self): + def _scalar_type(self) -> Union[TypeEngine[Any], Any]: return self.column_args[0].type -- 2.47.3