From: Dzmitar <17720985+dzmitar@users.noreply.github.com> Date: Wed, 14 Dec 2022 17:47:07 +0000 (+0100) Subject: Update type annotations in sqlalchemy.sql.selectable X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e708395bbc13e70e37cbad84839e73e8ce7b5a66;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Update type annotations in sqlalchemy.sql.selectable --- diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 37e05177bd..76af1ee7ef 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -140,6 +140,7 @@ if TYPE_CHECKING: from .functions import Function from .schema import ForeignKey from .schema import ForeignKeyConstraint + from .schema import Table from .type_api import TypeEngine from .visitors import _CloneCallableType @@ -1254,7 +1255,7 @@ class Join(roles.DMLTableRole, FromClause): return FromGrouping(self) @util.preload_module("sqlalchemy.sql.util") - def _populate_column_collection(self): + def _populate_column_collection(self) -> None: sqlutil = util.preloaded.sql_util columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [ c for c in self.right.c @@ -1466,12 +1467,11 @@ class Join(roles.DMLTableRole, FromClause): @classmethod def _joincond_trim_constraints( cls, - a, # TEMPtype: Union[Table, Join, Subquery] - b, # TEMPtype: Union[Table, Join] - constraints, # type: Dict - consider_as_foreign_keys, # type: Optional[Any] - ): - # type: (...) -> None + a: Union[Table, Join, Subquery], + b: Union[Table, Join], + constraints: Dict[Any, Any], + consider_as_foreign_keys: Optional[Any], + ) -> None: # more than one constraint matched. narrow down the list # to include just those FKCs that match exactly to # "consider_as_foreign_keys". @@ -1558,8 +1558,7 @@ class Join(roles.DMLTableRole, FromClause): class NoInit: - def __init__(self, *arg, **kw): - # type: (*Any, **Any) -> NoReturn + def __init__(self, *arg: Any, **kw: Any) -> NoReturn: raise NotImplementedError( "The %s class is not intended to be constructed " "directly. Please use the %s() standalone " @@ -1614,8 +1613,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): obj._init(*arg, **kw) return obj - def _init(self, selectable, name=None): - # type: (Any, Optional[str]) -> None + def _init(self, selectable: Any, name: Optional[str] = None) -> None: self.element = coercions.expect( roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self ) @@ -1636,8 +1634,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): super()._refresh_for_new_column(column) self.element._refresh_for_new_column(column) - def _populate_column_collection(self): - # type: () -> None + def _populate_column_collection(self) -> None: self.element._generate_fromclause_column_proxies(self) @util.ro_non_memoized_property @@ -2039,16 +2036,15 @@ class CTE( def _init( self, - selectable, # type: Select - name=None, # type: Optional[Any] - recursive=False, # type: bool - nesting=False, # type: bool - _cte_alias=None, # type: Optional[CTE] - _restates=None, # type: Optional[Any] - _prefixes=None, # type: Optional[Tuple[()]] - _suffixes=None, # type: Optional[Tuple[()]] - ): - # type: (...) -> None + selectable, + name=None, + recursive=False, + nesting=False, + _cte_alias=None, + _restates=None, + _prefixes=None, + _suffixes=None, + ) -> None: self.recursive = recursive self.nesting = nesting self._cte_alias = _cte_alias @@ -2060,8 +2056,7 @@ class CTE( self._suffixes = _suffixes super()._init(selectable, name=name) - def _populate_column_collection(self): - # type: () -> None + def _populate_column_collection(self) -> None: if self._cte_alias is not None: self._cte_alias._generate_fromclause_column_proxies(self) else: @@ -2847,26 +2842,25 @@ class FromGrouping(GroupedElement, FromClause): def __init__(self, element: FromClause): self.element = coercions.expect(roles.FromClauseRole, element) - def _init_collections(self): + def _init_collections(self) -> None: pass @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection: - # type: () -> ReadOnlyColumnCollection + def columns( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: return self.element.columns @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection: - # type: () -> ReadOnlyColumnCollection + def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: return self.element.columns @property - def primary_key(self): + def primary_key(self) -> Iterable[NamedColumn[Any]]: return self.element.primary_key @property - def foreign_keys(self) -> Set[ForeignKey]: - # type: () -> Set[ForeignKey] + def foreign_keys(self) -> Iterable[ForeignKey]: return self.element.foreign_keys def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: @@ -2877,8 +2871,7 @@ class FromGrouping(GroupedElement, FromClause): ) -> NamedFromGrouping: return NamedFromGrouping(self.element.alias(name=name, flat=flat)) - def _anonymous_fromclause(self, **kw: Any) -> FromGrouping: - # type: (**Any) -> FromGrouping + def _anonymous_fromclause(self, **kw: Any): return FromGrouping(self.element._anonymous_fromclause(**kw)) @util.ro_non_memoized_property @@ -3445,7 +3438,7 @@ class SelectBase( "from, use the :attr:`_expression.SelectBase.selected_columns` " "attribute.", ) - def c(self): + def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: return self._implicit_subquery.columns @property @@ -3486,7 +3479,6 @@ class SelectBase( @HasMemoized.memoized_attribute def _implicit_subquery(self) -> Subquery: - # type: () -> Subquery return self.subquery() def _scalar_type(self) -> TypeEngine[Any]: @@ -3662,12 +3654,10 @@ class SelectStatementGrouping(GroupedElement, SelectBase): element: SelectBase - def __init__(self, element) -> None: - # type: (Select) -> None + def __init__(self, element: Select[Any]) -> None: self.element = coercions.expect(roles.SelectStatementRole, element) def _ensure_disambiguated_names(self) -> SelectStatementGrouping: - # type: () -> SelectStatementGrouping new_element = self.element._ensure_disambiguated_names() if new_element is not self.element: return SelectStatementGrouping(new_element) @@ -3685,7 +3675,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase): ) @property - def select_statement(self): + def select_statement(self) -> SelectBase: return self.element def self_group(self: Self, against: Optional[OperatorType] = None) -> Self: @@ -4376,7 +4366,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): return True return False - def set_label_style(self, style): + def set_label_style(self, style: SelectLabelStyle) -> CompoundSelect: if self._label_style is not style: self = self._generate() select_0 = self.selects[0].set_label_style(style) @@ -4385,7 +4375,6 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): return self def _ensure_disambiguated_names(self) -> CompoundSelect: - # type: () -> CompoundSelect new_select = self.selects[0]._ensure_disambiguated_names() if new_select is not self.selects[0]: self = self._generate() @@ -4439,8 +4428,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): subquery, proxy_compound_columns=extra_col_iterator ) - def _refresh_for_new_column(self, column: ColumnClause) -> None: - # type: (ColumnClause) -> None + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: super()._refresh_for_new_column(column) for select in self.selects: select._refresh_for_new_column(column) @@ -6269,8 +6257,7 @@ class Select( meth = SelectState.get_plugin_class(self).all_selected_columns return list(meth(self)) - def _ensure_disambiguated_names(self) -> Select: - # type: () -> Select + def _ensure_disambiguated_names(self) -> Select[Any]: if self._label_style is LABEL_STYLE_NONE: self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY) return self @@ -6334,7 +6321,6 @@ class Select( subquery._columns._populate_separate_keys(prox) def _needs_parens_for_grouping(self) -> bool: - # type: () -> bool return self._has_row_limiting_clause or bool( self._order_by_clause.clauses ) @@ -6505,13 +6491,11 @@ class ScalarSelect( element: SelectBase - def __init__(self, element: Select) -> None: - # type: (Select) -> None + def __init__(self, element: Select[Any]) -> None: self.element = element self.type = element._scalar_type() - def __getattr__(self, attr: str) -> Callable: - # type: (str) -> Callable + def __getattr__(self, attr: str) -> Callable[..., Any]: return getattr(self.element, attr) def __getstate__(self): @@ -6522,7 +6506,7 @@ class ScalarSelect( self.type = state["type"] @property - def columns(self): + def columns(self) -> NoReturn: raise exc.InvalidRequestError( "Scalar Select expression has no " "columns; use this object directly " @@ -6682,7 +6666,6 @@ class Exists(UnaryExpression[bool]): return [] def _regroup(self, fn: Callable) -> SelectStatementGrouping: - # type: (Callable) -> SelectStatementGrouping element = self.element._ungroup() element = fn(element) return element.self_group(against=operators.exists) @@ -6828,10 +6811,9 @@ class TextualSelect(SelectBase, Executable, Generative): def __init__( self, text: TextClause, - columns: List[ColumnClause], + columns: List[ColumnClause[Any]], positional: bool = False, ) -> None: - # type: (TextClause, List[ColumnClause], bool) -> None self.element = text # convert for ORM attributes->columns, etc self.column_args = [ @@ -6871,11 +6853,10 @@ class TextualSelect(SelectBase, Executable, Generative): def _all_selected_columns(self) -> _SelectIterable: return self.column_args - def set_label_style(self, style): + def set_label_style(self, style: SelectLabelStyle) -> TextualSelect: return self def _ensure_disambiguated_names(self) -> TextualSelect: - # type: () -> TextualSelect return self @_generative @@ -6890,7 +6871,6 @@ class TextualSelect(SelectBase, Executable, Generative): def _generate_fromclause_column_proxies( self, fromclause: Subquery, proxy_compound_columns: zip = None ) -> None: - # type: (Subquery, zip) -> None if proxy_compound_columns: fromclause._columns._populate_separate_keys( c._make_proxy(fromclause, compound_select_cols=extra_cols) @@ -6913,7 +6893,6 @@ TextAsFrom = TextualSelect class AnnotatedFromClause(Annotated): def _copy_internals(self, **kw: Any) -> None: - # type: (**Any) -> None super()._copy_internals(**kw) if kw.get("ind_cols_on_fromclause", False): ee = self._Annotated__element # type: ignore