From: Dzmitar <17720985+dzmitar@users.noreply.github.com> Date: Tue, 13 Dec 2022 15:59:17 +0000 (+0100) Subject: Add type annotations to sqlalchemy.sql.selectable X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a7c85403c98fac25259ad3845d32d9dc7c4a65e9;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add type annotations to sqlalchemy.sql.selectable --- diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index fd4157afdd..37e05177bd 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -4,7 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-defs, allow-untyped-calls +# TODO_DELETE_mypy: allow-untyped-defs, allow-untyped-calls """The :class:`_expression.FromClause` class of SQL expression elements, representing @@ -78,6 +78,7 @@ from .base import HasMemoized from .base import Immutable from .coercions import _document_text_coercion from .elements import _anonymous_label +from .elements import BinaryExpression from .elements import BindParameter from .elements import BooleanClauseList from .elements import ClauseElement @@ -1292,7 +1293,10 @@ class Join(roles.DMLTableRole, FromClause): # set up a special replace function that will replace for # ColumnClause with parent table referring to those # replaced FromClause objects - def replace(obj, **kw): + def replace( + obj: Union[BinaryExpression[Any], ColumnClause[Any]], + **kw: Any, + ) -> Optional[KeyedColumnElement[ColumnElement[Any]]]: if isinstance(obj, ColumnClause) and obj.table in new_froms: newelem = new_froms[obj.table].corresponding_column(obj) return newelem @@ -1461,8 +1465,13 @@ class Join(roles.DMLTableRole, FromClause): @classmethod def _joincond_trim_constraints( - cls, a, b, constraints, consider_as_foreign_keys + cls, + a, # TEMPtype: Union[Table, Join, Subquery] + b, # TEMPtype: Union[Table, Join] + constraints, # type: Dict + consider_as_foreign_keys, # type: Optional[Any] ): + # type: (...) -> None # more than one constraint matched. narrow down the list # to include just those FKCs that match exactly to # "consider_as_foreign_keys". @@ -1550,6 +1559,7 @@ class Join(roles.DMLTableRole, FromClause): class NoInit: def __init__(self, *arg, **kw): + # type: (*Any, **Any) -> NoReturn raise NotImplementedError( "The %s class is not intended to be constructed " "directly. Please use the %s() standalone " @@ -1605,6 +1615,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): return obj def _init(self, selectable, name=None): + # type: (Any, Optional[str]) -> None self.element = coercions.expect( roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self ) @@ -1626,6 +1637,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): self.element._refresh_for_new_column(column) def _populate_column_collection(self): + # type: () -> None self.element._generate_fromclause_column_proxies(self) @util.ro_non_memoized_property @@ -2027,15 +2039,16 @@ class CTE( def _init( self, - selectable, - name=None, - recursive=False, - nesting=False, - _cte_alias=None, - _restates=None, - _prefixes=None, - _suffixes=None, + selectable, # type: Select + name=None, # type: Optional[Any] + recursive=False, # type: bool + nesting=False, # type: bool + _cte_alias=None, # type: Optional[CTE] + _restates=None, # type: Optional[Any] + _prefixes=None, # type: Optional[Tuple[()]] + _suffixes=None, # type: Optional[Tuple[()]] ): + # type: (...) -> None self.recursive = recursive self.nesting = nesting self._cte_alias = _cte_alias @@ -2048,6 +2061,7 @@ class CTE( super()._init(selectable, name=name) def _populate_column_collection(self): + # type: () -> None if self._cte_alias is not None: self._cte_alias._generate_fromclause_column_proxies(self) else: @@ -2837,11 +2851,13 @@ class FromGrouping(GroupedElement, FromClause): pass @util.ro_non_memoized_property - def columns(self): + def columns(self) -> ReadOnlyColumnCollection: + # type: () -> ReadOnlyColumnCollection return self.element.columns @util.ro_non_memoized_property - def c(self): + def c(self) -> ReadOnlyColumnCollection: + # type: () -> ReadOnlyColumnCollection return self.element.columns @property @@ -2849,7 +2865,8 @@ class FromGrouping(GroupedElement, FromClause): return self.element.primary_key @property - def foreign_keys(self): + def foreign_keys(self) -> Set[ForeignKey]: + # type: () -> Set[ForeignKey] return self.element.foreign_keys def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: @@ -2860,7 +2877,8 @@ class FromGrouping(GroupedElement, FromClause): ) -> NamedFromGrouping: return NamedFromGrouping(self.element.alias(name=name, flat=flat)) - def _anonymous_fromclause(self, **kw): + def _anonymous_fromclause(self, **kw: Any) -> FromGrouping: + # type: (**Any) -> FromGrouping return FromGrouping(self.element._anonymous_fromclause(**kw)) @util.ro_non_memoized_property @@ -3467,7 +3485,8 @@ class SelectBase( return self._implicit_subquery.select(*arg, **kw) @HasMemoized.memoized_attribute - def _implicit_subquery(self): + def _implicit_subquery(self) -> Subquery: + # type: () -> Subquery return self.subquery() def _scalar_type(self) -> TypeEngine[Any]: @@ -3643,10 +3662,12 @@ class SelectStatementGrouping(GroupedElement, SelectBase): element: SelectBase - def __init__(self, element): + def __init__(self, element) -> None: + # type: (Select) -> None self.element = coercions.expect(roles.SelectStatementRole, element) - def _ensure_disambiguated_names(self): + def _ensure_disambiguated_names(self) -> SelectStatementGrouping: + # type: () -> SelectStatementGrouping new_element = self.element._ensure_disambiguated_names() if new_element is not self.element: return SelectStatementGrouping(new_element) @@ -4363,7 +4384,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): return self - def _ensure_disambiguated_names(self): + def _ensure_disambiguated_names(self) -> CompoundSelect: + # type: () -> CompoundSelect new_select = self.selects[0]._ensure_disambiguated_names() if new_select is not self.selects[0]: self = self._generate() @@ -4417,7 +4439,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): subquery, proxy_compound_columns=extra_col_iterator ) - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnClause) -> None: + # type: (ColumnClause) -> None super()._refresh_for_new_column(column) for select in self.selects: select._refresh_for_new_column(column) @@ -4548,7 +4571,10 @@ class SelectState(util.MemoizedSlots, CompileState): pa = prefix_anon_map() names = set() - def go(c, col_name=None): + def go( + c: Union[ColumnElement[Any], TextClause], + col_name: Optional[str] = None, + ) -> Optional[str]: if c._is_text_clause: return None @@ -5627,7 +5653,10 @@ class Select( # 3. clone everything else, making sure we use columns # corresponding to the froms we just made. - def replace(obj, **kw): + def replace( + obj: Union[BinaryExpression[Any], ColumnClause[Any]], + **kw: Any, + ) -> Optional[KeyedColumnElement[ColumnElement[Any]]]: if isinstance(obj, ColumnClause) and obj.table in new_froms: newelem = new_froms[obj.table].corresponding_column(obj) return newelem @@ -6240,7 +6269,8 @@ class Select( meth = SelectState.get_plugin_class(self).all_selected_columns return list(meth(self)) - def _ensure_disambiguated_names(self): + def _ensure_disambiguated_names(self) -> Select: + # type: () -> Select if self._label_style is LABEL_STYLE_NONE: self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY) return self @@ -6303,7 +6333,8 @@ class Select( subquery._columns._populate_separate_keys(prox) - def _needs_parens_for_grouping(self): + def _needs_parens_for_grouping(self) -> bool: + # type: () -> bool return self._has_row_limiting_clause or bool( self._order_by_clause.clauses ) @@ -6474,11 +6505,13 @@ class ScalarSelect( element: SelectBase - def __init__(self, element): + def __init__(self, element: Select) -> None: + # type: (Select) -> None self.element = element self.type = element._scalar_type() - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Callable: + # type: (str) -> Callable return getattr(self.element, attr) def __getstate__(self): @@ -6648,7 +6681,8 @@ class Exists(UnaryExpression[bool]): def _from_objects(self) -> List[FromClause]: return [] - def _regroup(self, fn): + def _regroup(self, fn: Callable) -> SelectStatementGrouping: + # type: (Callable) -> SelectStatementGrouping element = self.element._ungroup() element = fn(element) return element.self_group(against=operators.exists) @@ -6791,7 +6825,13 @@ class TextualSelect(SelectBase, Executable, Generative): is_text = True is_select = True - def __init__(self, text, columns, positional=False): + def __init__( + self, + text: TextClause, + columns: List[ColumnClause], + positional: bool = False, + ) -> None: + # type: (TextClause, List[ColumnClause], bool) -> None self.element = text # convert for ORM attributes->columns, etc self.column_args = [ @@ -6834,7 +6874,8 @@ class TextualSelect(SelectBase, Executable, Generative): def set_label_style(self, style): return self - def _ensure_disambiguated_names(self): + def _ensure_disambiguated_names(self) -> TextualSelect: + # type: () -> TextualSelect return self @_generative @@ -6847,8 +6888,9 @@ class TextualSelect(SelectBase, Executable, Generative): return self def _generate_fromclause_column_proxies( - self, fromclause, proxy_compound_columns=None - ): + self, fromclause: Subquery, proxy_compound_columns: zip = None + ) -> None: + # type: (Subquery, zip) -> None if proxy_compound_columns: fromclause._columns._populate_separate_keys( c._make_proxy(fromclause, compound_select_cols=extra_cols) @@ -6870,7 +6912,8 @@ TextAsFrom = TextualSelect class AnnotatedFromClause(Annotated): - def _copy_internals(self, **kw): + def _copy_internals(self, **kw: Any) -> None: + # type: (**Any) -> None super()._copy_internals(**kw) if kw.get("ind_cols_on_fromclause", False): ee = self._Annotated__element # type: ignore