]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Update type annotations in sqlalchemy.sql.selectable
authorDzmitar <17720985+dzmitar@users.noreply.github.com>
Wed, 14 Dec 2022 17:47:07 +0000 (18:47 +0100)
committerDzmitar <17720985+dzmitar@users.noreply.github.com>
Wed, 14 Dec 2022 17:47:07 +0000 (18:47 +0100)
lib/sqlalchemy/sql/selectable.py

index 37e05177bd2909bc5cc297203d588ff79d61b6d4..76af1ee7ef1e10d3d526440dad3000581ee83941 100644 (file)
@@ -140,6 +140,7 @@ if TYPE_CHECKING:
     from .functions import Function
     from .schema import ForeignKey
     from .schema import ForeignKeyConstraint
+    from .schema import Table
     from .type_api import TypeEngine
     from .visitors import _CloneCallableType
 
@@ -1254,7 +1255,7 @@ class Join(roles.DMLTableRole, FromClause):
         return FromGrouping(self)
 
     @util.preload_module("sqlalchemy.sql.util")
-    def _populate_column_collection(self):
+    def _populate_column_collection(self) -> None:
         sqlutil = util.preloaded.sql_util
         columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [
             c for c in self.right.c
@@ -1466,12 +1467,11 @@ class Join(roles.DMLTableRole, FromClause):
     @classmethod
     def _joincond_trim_constraints(
         cls,
-        a,  # TEMPtype: Union[Table, Join, Subquery]
-        b,  # TEMPtype: Union[Table, Join]
-        constraints,  # type: Dict
-        consider_as_foreign_keys,  # type: Optional[Any]
-    ):
-        # type: (...) -> None
+        a: Union[Table, Join, Subquery],
+        b: Union[Table, Join],
+        constraints: Dict[Any, Any],
+        consider_as_foreign_keys: Optional[Any],
+    ) -> None:
         # more than one constraint matched.  narrow down the list
         # to include just those FKCs that match exactly to
         # "consider_as_foreign_keys".
@@ -1558,8 +1558,7 @@ class Join(roles.DMLTableRole, FromClause):
 
 
 class NoInit:
-    def __init__(self, *arg, **kw):
-        # type: (*Any, **Any) -> NoReturn
+    def __init__(self, *arg: Any, **kw: Any) -> NoReturn:
         raise NotImplementedError(
             "The %s class is not intended to be constructed "
             "directly.  Please use the %s() standalone "
@@ -1614,8 +1613,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
         obj._init(*arg, **kw)
         return obj
 
-    def _init(self, selectable, name=None):
-        # type: (Any, Optional[str]) -> None
+    def _init(self, selectable: Any, name: Optional[str] = None) -> None:
         self.element = coercions.expect(
             roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self
         )
@@ -1636,8 +1634,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
         super()._refresh_for_new_column(column)
         self.element._refresh_for_new_column(column)
 
-    def _populate_column_collection(self):
-        # type: () -> None
+    def _populate_column_collection(self) -> None:
         self.element._generate_fromclause_column_proxies(self)
 
     @util.ro_non_memoized_property
@@ -2039,16 +2036,15 @@ class CTE(
 
     def _init(
         self,
-        selectable,  # type: Select
-        name=None,  # type: Optional[Any]
-        recursive=False,  # type: bool
-        nesting=False,  # type: bool
-        _cte_alias=None,  # type: Optional[CTE]
-        _restates=None,  # type: Optional[Any]
-        _prefixes=None,  # type: Optional[Tuple[()]]
-        _suffixes=None,  # type: Optional[Tuple[()]]
-    ):
-        # type: (...) -> None
+        selectable,
+        name=None,
+        recursive=False,
+        nesting=False,
+        _cte_alias=None,
+        _restates=None,
+        _prefixes=None,
+        _suffixes=None,
+    ) -> None:
         self.recursive = recursive
         self.nesting = nesting
         self._cte_alias = _cte_alias
@@ -2060,8 +2056,7 @@ class CTE(
             self._suffixes = _suffixes
         super()._init(selectable, name=name)
 
-    def _populate_column_collection(self):
-        # type: () -> None
+    def _populate_column_collection(self) -> None:
         if self._cte_alias is not None:
             self._cte_alias._generate_fromclause_column_proxies(self)
         else:
@@ -2847,26 +2842,25 @@ class FromGrouping(GroupedElement, FromClause):
     def __init__(self, element: FromClause):
         self.element = coercions.expect(roles.FromClauseRole, element)
 
-    def _init_collections(self):
+    def _init_collections(self) -> None:
         pass
 
     @util.ro_non_memoized_property
-    def columns(self) -> ReadOnlyColumnCollection:
-        # type: () -> ReadOnlyColumnCollection
+    def columns(
+        self,
+    ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         return self.element.columns
 
     @util.ro_non_memoized_property
-    def c(self) -> ReadOnlyColumnCollection:
-        # type: () -> ReadOnlyColumnCollection
+    def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         return self.element.columns
 
     @property
-    def primary_key(self):
+    def primary_key(self) -> Iterable[NamedColumn[Any]]:
         return self.element.primary_key
 
     @property
-    def foreign_keys(self) -> Set[ForeignKey]:
-        # type: () -> Set[ForeignKey]
+    def foreign_keys(self) -> Iterable[ForeignKey]:
         return self.element.foreign_keys
 
     def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
@@ -2877,8 +2871,7 @@ class FromGrouping(GroupedElement, FromClause):
     ) -> NamedFromGrouping:
         return NamedFromGrouping(self.element.alias(name=name, flat=flat))
 
-    def _anonymous_fromclause(self, **kw: Any) -> FromGrouping:
-        # type: (**Any) -> FromGrouping
+    def _anonymous_fromclause(self, **kw: Any):
         return FromGrouping(self.element._anonymous_fromclause(**kw))
 
     @util.ro_non_memoized_property
@@ -3445,7 +3438,7 @@ class SelectBase(
         "from, use the :attr:`_expression.SelectBase.selected_columns` "
         "attribute.",
     )
-    def c(self):
+    def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         return self._implicit_subquery.columns
 
     @property
@@ -3486,7 +3479,6 @@ class SelectBase(
 
     @HasMemoized.memoized_attribute
     def _implicit_subquery(self) -> Subquery:
-        # type: () -> Subquery
         return self.subquery()
 
     def _scalar_type(self) -> TypeEngine[Any]:
@@ -3662,12 +3654,10 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
 
     element: SelectBase
 
-    def __init__(self, element) -> None:
-        # type: (Select) -> None
+    def __init__(self, element: Select[Any]) -> None:
         self.element = coercions.expect(roles.SelectStatementRole, element)
 
     def _ensure_disambiguated_names(self) -> SelectStatementGrouping:
-        # type: () -> SelectStatementGrouping
         new_element = self.element._ensure_disambiguated_names()
         if new_element is not self.element:
             return SelectStatementGrouping(new_element)
@@ -3685,7 +3675,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
         )
 
     @property
-    def select_statement(self):
+    def select_statement(self) -> SelectBase:
         return self.element
 
     def self_group(self: Self, against: Optional[OperatorType] = None) -> Self:
@@ -4376,7 +4366,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
                 return True
         return False
 
-    def set_label_style(self, style):
+    def set_label_style(self, style: SelectLabelStyle) -> CompoundSelect:
         if self._label_style is not style:
             self = self._generate()
             select_0 = self.selects[0].set_label_style(style)
@@ -4385,7 +4375,6 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
         return self
 
     def _ensure_disambiguated_names(self) -> CompoundSelect:
-        # type: () -> CompoundSelect
         new_select = self.selects[0]._ensure_disambiguated_names()
         if new_select is not self.selects[0]:
             self = self._generate()
@@ -4439,8 +4428,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
             subquery, proxy_compound_columns=extra_col_iterator
         )
 
-    def _refresh_for_new_column(self, column: ColumnClause) -> None:
-        # type: (ColumnClause) -> None
+    def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
         super()._refresh_for_new_column(column)
         for select in self.selects:
             select._refresh_for_new_column(column)
@@ -6269,8 +6257,7 @@ class Select(
         meth = SelectState.get_plugin_class(self).all_selected_columns
         return list(meth(self))
 
-    def _ensure_disambiguated_names(self) -> Select:
-        # type: () -> Select
+    def _ensure_disambiguated_names(self) -> Select[Any]:
         if self._label_style is LABEL_STYLE_NONE:
             self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY)
         return self
@@ -6334,7 +6321,6 @@ class Select(
         subquery._columns._populate_separate_keys(prox)
 
     def _needs_parens_for_grouping(self) -> bool:
-        # type: () -> bool
         return self._has_row_limiting_clause or bool(
             self._order_by_clause.clauses
         )
@@ -6505,13 +6491,11 @@ class ScalarSelect(
 
     element: SelectBase
 
-    def __init__(self, element: Select) -> None:
-        # type: (Select) -> None
+    def __init__(self, element: Select[Any]) -> None:
         self.element = element
         self.type = element._scalar_type()
 
-    def __getattr__(self, attr: str) -> Callable:
-        # type: (str) -> Callable
+    def __getattr__(self, attr: str) -> Callable[..., Any]:
         return getattr(self.element, attr)
 
     def __getstate__(self):
@@ -6522,7 +6506,7 @@ class ScalarSelect(
         self.type = state["type"]
 
     @property
-    def columns(self):
+    def columns(self) -> NoReturn:
         raise exc.InvalidRequestError(
             "Scalar Select expression has no "
             "columns; use this object directly "
@@ -6682,7 +6666,6 @@ class Exists(UnaryExpression[bool]):
         return []
 
     def _regroup(self, fn: Callable) -> SelectStatementGrouping:
-        # type: (Callable) -> SelectStatementGrouping
         element = self.element._ungroup()
         element = fn(element)
         return element.self_group(against=operators.exists)
@@ -6828,10 +6811,9 @@ class TextualSelect(SelectBase, Executable, Generative):
     def __init__(
         self,
         text: TextClause,
-        columns: List[ColumnClause],
+        columns: List[ColumnClause[Any]],
         positional: bool = False,
     ) -> None:
-        # type: (TextClause, List[ColumnClause], bool) -> None
         self.element = text
         # convert for ORM attributes->columns, etc
         self.column_args = [
@@ -6871,11 +6853,10 @@ class TextualSelect(SelectBase, Executable, Generative):
     def _all_selected_columns(self) -> _SelectIterable:
         return self.column_args
 
-    def set_label_style(self, style):
+    def set_label_style(self, style: SelectLabelStyle) -> TextualSelect:
         return self
 
     def _ensure_disambiguated_names(self) -> TextualSelect:
-        # type: () -> TextualSelect
         return self
 
     @_generative
@@ -6890,7 +6871,6 @@ class TextualSelect(SelectBase, Executable, Generative):
     def _generate_fromclause_column_proxies(
         self, fromclause: Subquery, proxy_compound_columns: zip = None
     ) -> None:
-        # type: (Subquery, zip) -> None
         if proxy_compound_columns:
             fromclause._columns._populate_separate_keys(
                 c._make_proxy(fromclause, compound_select_cols=extra_cols)
@@ -6913,7 +6893,6 @@ TextAsFrom = TextualSelect
 
 class AnnotatedFromClause(Annotated):
     def _copy_internals(self, **kw: Any) -> None:
-        # type: (**Any) -> None
         super()._copy_internals(**kw)
         if kw.get("ind_cols_on_fromclause", False):
             ee = self._Annotated__element  # type: ignore