]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Update type annotations in sqlalchemy.sql.selectable
authorDzmitar <17720985+dzmitar@users.noreply.github.com>
Thu, 15 Dec 2022 17:25:02 +0000 (18:25 +0100)
committerDzmitar <17720985+dzmitar@users.noreply.github.com>
Thu, 15 Dec 2022 17:25:02 +0000 (18:25 +0100)
lib/sqlalchemy/sql/selectable.py

index e91eb4c8ecca5d8d3060c98a59e34326eec9c814..60f9a1e6424653da012655b18e6c33a75d8a0e2a 100644 (file)
@@ -297,7 +297,7 @@ class Selectable(ReturnsRows):
             :ref:`tutorial_lateral_correlation` -  overview of usage.
 
         """
-        return Lateral._construct(self, name)
+        return Lateral._construct(self, name)  # type: ignore
 
     @util.deprecated(
         "1.4",
@@ -759,7 +759,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
 
         """
 
-        return Alias._construct(self, name)
+        return Alias._construct(self, name)  # type: ignore
 
     def tablesample(
         self,
@@ -780,7 +780,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
             :func:`_expression.tablesample` - usage guidelines and parameters
 
         """
-        return TableSample._construct(self, sampling, name, seed)
+        return TableSample._construct(self, sampling, name, seed)  # type: ignore # noqa: E501
 
     def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
         """Return ``True`` if this :class:`_expression.FromClause` is
@@ -1218,7 +1218,7 @@ class Join(roles.DMLTableRole, FromClause):
         ).self_group()
 
         if onclause is None:
-            self.onclause = self._match_primaries(self.left, self.right)
+            self.onclause = self._match_primaries(self.left, self.right)  # type: ignore # noqa: E501
         else:
             # note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba
             # not merged yet
@@ -1300,6 +1300,7 @@ class Join(roles.DMLTableRole, FromClause):
             if isinstance(obj, ColumnClause) and obj.table in new_froms:
                 newelem = new_froms[obj.table].corresponding_column(obj)
                 return newelem
+            return None
 
         kw["replace"] = replace
 
@@ -1523,7 +1524,7 @@ class Join(roles.DMLTableRole, FromClause):
         self,
         name: Optional[str] = None,
         flat: bool = False,
-    ):  # WIP -> ??? NamedFromClause: # Union[Join, Subquery]:
+    ) -> TODO_Any:
         sqlutil = util.preloaded.sql_util
         if flat:
             if name is not None:
@@ -1760,13 +1761,13 @@ class TableValuedAlias(LateralFromClause, Alias):
         ("_render_derived_w_types", InternalTraversal.dp_boolean),
     ]
 
-    def _init(
+    def _init(  # type: ignore
         self,
         selectable,
         name=None,
         table_value_type=None,
         joins_implicitly=False,
-    ):
+    ) -> None:
         super()._init(selectable, name=name)
 
         self.joins_implicitly = joins_implicitly
@@ -1777,7 +1778,7 @@ class TableValuedAlias(LateralFromClause, Alias):
         )
 
     @HasMemoized.memoized_attribute
-    def column(self):
+    def column(self) -> TableValuedColumn[Any]:
         """Return a column expression representing this
         :class:`_sql.TableValuedAlias`.
 
@@ -1808,7 +1809,7 @@ class TableValuedAlias(LateralFromClause, Alias):
 
         """
 
-        tva = TableValuedAlias._construct(
+        tva: TableValuedAlias = TableValuedAlias._construct(  # type: ignore
             self,
             name=name,
             table_value_type=self._tableval_type,
@@ -1834,7 +1835,11 @@ class TableValuedAlias(LateralFromClause, Alias):
         tva._is_lateral = True
         return tva
 
-    def render_derived(self, name=None, with_types=False):
+    def render_derived(
+        self,
+        name: Optional[str] = None,
+        with_types: bool = False,
+    ) -> TableValuedAlias:
         """Apply "render derived" to this :class:`_sql.TableValuedAlias`.
 
         This has the effect of the individual column names listed out
@@ -1882,7 +1887,7 @@ class TableValuedAlias(LateralFromClause, Alias):
 
         # construct against original to prevent memory growth
         # for repeated generations
-        new_alias = TableValuedAlias._construct(
+        new_alias: TableValuedAlias = TableValuedAlias._construct(  # type: ignore # noqa: E501
             self.element,
             name=name,
             table_value_type=self._tableval_type,
@@ -1947,7 +1952,7 @@ class TableSample(FromClauseAlias):
     __visit_name__ = "tablesample"
 
     _traverse_internals: _TraverseInternalsType = (
-        AliasedReturnsRows._traverse_internals
+        AliasedReturnsRows._traverse_internals  # type: ignore
         + [
             ("sampling", InternalTraversal.dp_clauseelement),
             ("seed", InternalTraversal.dp_clauseelement),
@@ -1967,16 +1972,22 @@ class TableSample(FromClauseAlias):
         )
 
     @util.preload_module("sqlalchemy.sql.functions")
-    def _init(self, selectable, sampling, name=None, seed=None):
+    def _init(  # type: ignore
+        self,
+        selectable: Any,
+        sampling: Union[float, Function[Any]],
+        name: Optional[str] = None,
+        seed: Optional[roles.ExpressionElementRole[Any]] = None,
+    ) -> None:
         functions = util.preloaded.sql_functions
         if not isinstance(sampling, functions.Function):
             sampling = functions.func.system(sampling)
 
-        self.sampling = sampling
+        self.sampling: Function[Any] = sampling
         self.seed = seed
         super()._init(selectable, name=name)
 
-    def _get_method(self):
+    def _get_method(self) -> Function[Any]:
         return self.sampling
 
 
@@ -2009,7 +2020,7 @@ class CTE(
     __visit_name__ = "cte"
 
     _traverse_internals: _TraverseInternalsType = (
-        AliasedReturnsRows._traverse_internals
+        AliasedReturnsRows._traverse_internals  # type: ignore
         + [
             ("_cte_alias", InternalTraversal.dp_clauseelement),
             ("_restates", InternalTraversal.dp_clauseelement),
@@ -2041,14 +2052,14 @@ class CTE(
 
     def _init(
         self,
-        selectable,
-        name=None,
-        recursive=False,
-        nesting=False,
-        _cte_alias=None,
-        _restates=None,
-        _prefixes=None,
-        _suffixes=None,
+        selectable: Select[Any],
+        name: Optional[str] = None,
+        recursive: bool = False,
+        nesting: bool = False,
+        _cte_alias: Optional[CTE] = None,
+        _restates: Optional[Any] = None,
+        _prefixes: Optional[Tuple[()]] = None,
+        _suffixes: Optional[Tuple[()]] = None,
     ) -> None:
         self.recursive = recursive
         self.nesting = nesting
@@ -2081,7 +2092,7 @@ class CTE(
             :func:`_expression.alias`
 
         """
-        return CTE._construct(
+        return CTE._construct(  # type: ignore
             self.element,
             name=name,
             recursive=self.recursive,
@@ -2110,7 +2121,7 @@ class CTE(
             self.element
         ), f"CTE element f{self.element} does not support union()"
 
-        return CTE._construct(
+        return CTE._construct(  # type: ignore
             self.element.union(*other),
             name=self.name,
             recursive=self.recursive,
@@ -2140,7 +2151,7 @@ class CTE(
             self.element
         ), f"CTE element f{self.element} does not support union_all()"
 
-        return CTE._construct(
+        return CTE._construct(  # type: ignore
             self.element.union_all(*other),
             name=self.name,
             recursive=self.recursive,
@@ -2150,7 +2161,7 @@ class CTE(
             _suffixes=self._suffixes,
         )
 
-    def _get_reference_cte(self):
+    def _get_reference_cte(self) -> Optional[Any]:
         """
         A recursive CTE is updated to attach the recursive part.
         Updated CTEs should still refer to the original CTE.
@@ -2769,7 +2780,7 @@ class HasCTE(roles.HasCTERole, SelectsRows):
             :meth:`_expression.HasCTE.cte`.
 
         """
-        return CTE._construct(
+        return CTE._construct(  # type: ignore
             self, name=name, recursive=recursive, nesting=nesting
         )
 
@@ -2831,7 +2842,7 @@ class Subquery(AliasedReturnsRows):
         "construct before constructing a subquery object, or with the ORM "
         "use the :meth:`_query.Query.scalar_subquery` method.",
     )
-    def as_scalar(self):
+    def as_scalar(self) -> ScalarSelect[Any]:
         return self.element.set_label_style(LABEL_STYLE_NONE).scalar_subquery()
 
 
@@ -2876,7 +2887,7 @@ class FromGrouping(GroupedElement, FromClause):
     ) -> NamedFromGrouping:
         return NamedFromGrouping(self.element.alias(name=name, flat=flat))
 
-    def _anonymous_fromclause(self, **kw: Any):
+    def _anonymous_fromclause(self, **kw: Any):  # type: ignore
         return FromGrouping(self.element._anonymous_fromclause(**kw))
 
     @util.ro_non_memoized_property
@@ -2887,10 +2898,10 @@ class FromGrouping(GroupedElement, FromClause):
     def _from_objects(self) -> List[FromClause]:
         return self.element._from_objects
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, FromClause]:
         return {"element": self.element}
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, FromClause]) -> None:
         self.element = state["element"]
 
 
@@ -3091,7 +3102,7 @@ class ForUpdateArg(ClauseElement):
         else:
             return ForUpdateArg(**cast("Dict[str, Any]", with_for_update))
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         return (
             isinstance(other, ForUpdateArg)
             and other.nowait == self.nowait
@@ -3101,10 +3112,10 @@ class ForUpdateArg(ClauseElement):
             and other.of is self.of
         )
 
-    def __ne__(self, other):
+    def __ne__(self, other: Any) -> bool:
         return not self.__eq__(other)
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return id(self)
 
     def __init__(
@@ -3183,7 +3194,7 @@ class Values(roles.InElementRole, Generative, LateralFromClause):
         self.named_with_column = not self._unnamed
 
     @property
-    def _column_types(self):
+    def _column_types(self) -> List[TypeEngine[Any]]:
         return [col.type for col in self._column_args]
 
     @_generative
@@ -3310,10 +3321,10 @@ class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]):
         self.literal_binds = literal_binds
 
     @property
-    def _column_types(self):
+    def _column_types(self) -> List[TypeEngine[Any]]:
         return [col.type for col in self._column_args]
 
-    def __clause_element__(self):
+    def __clause_element__(self) -> ScalarValues:
         return self
 
 
@@ -3447,7 +3458,7 @@ class SelectBase(
         return self._implicit_subquery.columns
 
     @property
-    def columns(self):
+    def columns(self):  # type: ignore
         return self.c
 
     def get_label_style(self) -> SelectLabelStyle:
@@ -3479,7 +3490,7 @@ class SelectBase(
         "first in order to create "
         "a subquery, which then can be selected.",
     )
-    def select(self, *arg, **kw):
+    def select(self, *arg: Any, **kw: Any) -> Select[Any]:
         return self._implicit_subquery.select(*arg, **kw)
 
     @HasMemoized.memoized_attribute
@@ -3496,7 +3507,7 @@ class SelectBase(
         "removed in a future release.  Please refer to "
         ":meth:`_expression.SelectBase.scalar_subquery`.",
     )
-    def as_scalar(self):
+    def as_scalar(self) -> ScalarSelect[Any]:
         return self.scalar_subquery()
 
     def exists(self) -> Exists:
@@ -3543,7 +3554,7 @@ class SelectBase(
         if self._label_style is not LABEL_STYLE_NONE:
             self = self.set_label_style(LABEL_STYLE_NONE)
 
-        return ScalarSelect(self)
+        return ScalarSelect(self)  # type: ignore
 
     def label(self, name: Optional[str]) -> Label[Any]:
         """Return a 'scalar' representation of this selectable, embedded as a
@@ -3569,7 +3580,7 @@ class SelectBase(
             :ref:`tutorial_lateral_correlation` -  overview of usage.
 
         """
-        return Lateral._factory(self, name)
+        return Lateral._factory(self, name)  # type: ignore
 
     def subquery(self, name: Optional[str] = None) -> Subquery:
         """Return a subquery of this :class:`_expression.SelectBase`.
@@ -3611,9 +3622,9 @@ class SelectBase(
 
         """
 
-        return Subquery._construct(self._ensure_disambiguated_names(), name)
+        return Subquery._construct(self._ensure_disambiguated_names(), name)  # type: ignore # noqa: E501
 
-    def _ensure_disambiguated_names(self):
+    def _ensure_disambiguated_names(self):  # type: ignore
         """Ensure that the names generated by this selectbase will be
         disambiguated in some way, if possible.
 
@@ -3663,7 +3674,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
         self.element = coercions.expect(roles.SelectStatementRole, element)
 
     def _ensure_disambiguated_names(self) -> SelectStatementGrouping:
-        new_element = self.element._ensure_disambiguated_names()
+        new_element = self.element._ensure_disambiguated_names()  # type: ignore # noqa: E501
         if new_element is not self.element:
             return SelectStatementGrouping(new_element)
         else:
@@ -3676,8 +3687,9 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
         self, label_style: SelectLabelStyle
     ) -> SelectStatementGrouping:
         return SelectStatementGrouping(
-            self.element.set_label_style(label_style)
-        )
+            self.element.set_label_style(label_style)  # type: ignore
+        )  # Argument 1 to "SelectStatementGrouping" has incompatible
+        # type "SelectBase"; expected "Select[Any]"  [arg-type]mypy(error)
 
     @property
     def select_statement(self) -> SelectBase:
@@ -4249,7 +4261,7 @@ class GenerativeSelect(SelectBase, Generative):
 @CompileState.plugin_for("default", "compound_select")
 class CompoundSelectState(CompileState):
     @util.memoized_property
-    def _label_resolve_dict(self):
+    def _label_resolve_dict(self):  # type: ignore
         # TODO: this is hacky and slow
         hacky_subquery = self.statement.subquery()
         hacky_subquery.named_with_column = False
@@ -4363,7 +4375,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
     def self_group(
         self, against: Optional[OperatorType] = None
     ) -> GroupedElement:
-        return SelectStatementGrouping(self)
+        return SelectStatementGrouping(self)  # type: ignore
 
     def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
         for s in self.selects:
@@ -4380,7 +4392,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
         return self
 
     def _ensure_disambiguated_names(self) -> CompoundSelect:
-        new_select = self.selects[0]._ensure_disambiguated_names()
+        new_select = self.selects[0]._ensure_disambiguated_names()  # type: ignore # noqa: E501
         if new_select is not self.selects[0]:
             self = self._generate()
             self.selects = [new_select] + self.selects[1:]
@@ -4572,26 +4584,26 @@ class SelectState(util.MemoizedSlots, CompileState):
                 return None
 
             elif not dedupe:
-                name = c._proxy_key
+                name = c._proxy_key  # type: ignore
                 if name is None:
                     name = "_no_label"
                 return name
 
-            name = c._tq_key_label if table_qualified else c._proxy_key
+            name = c._tq_key_label if table_qualified else c._proxy_key  # type: ignore # noqa: E501
 
             if name is None:
                 name = "_no_label"
                 if name in names:
-                    return c._anon_label(name) % pa
+                    return c._anon_label(name) % pa  # type: ignore
                 else:
                     names.add(name)
                     return name
 
             elif name in names:
                 return (
-                    c._anon_tq_key_label % pa
+                    c._anon_tq_key_label % pa  # type: ignore
                     if table_qualified
-                    else c._anon_key_label % pa
+                    else c._anon_key_label % pa  # type: ignore
                 )
             else:
                 names.add(name)
@@ -4655,7 +4667,7 @@ class SelectState(util.MemoizedSlots, CompileState):
         if froms:
             toremove = set(
                 itertools.chain.from_iterable(
-                    [_expand_cloned(f._hide_froms) for f in froms]
+                    [_expand_cloned(f._hide_froms) for f in froms]  # type: ignore # noqa: E501
                 )
             )
             if toremove:
@@ -4703,8 +4715,8 @@ class SelectState(util.MemoizedSlots, CompileState):
                     f
                     for f in froms
                     if f
-                    not in _cloned_intersection(
-                        _cloned_intersection(
+                    not in _cloned_intersection(  # type: ignore
+                        _cloned_intersection(  # type: ignore
                             froms, explicit_correlate_froms or ()
                         ),
                         to_correlate,
@@ -4717,8 +4729,8 @@ class SelectState(util.MemoizedSlots, CompileState):
                 f
                 for f in froms
                 if f
-                not in _cloned_difference(
-                    _cloned_intersection(
+                not in _cloned_difference(  # type: ignore
+                    _cloned_intersection(  # type: ignore
                         froms, explicit_correlate_froms or ()
                     ),
                     self.statement._correlate_except,
@@ -4735,7 +4747,7 @@ class SelectState(util.MemoizedSlots, CompileState):
                 f
                 for f in froms
                 if f
-                not in _cloned_intersection(froms, implicit_correlate_froms)
+                not in _cloned_intersection(froms, implicit_correlate_froms)  # type: ignore # noqa: E501
             ]
 
             if not len(froms):
@@ -5172,7 +5184,11 @@ class Select(
 
         return self.where(*criteria)
 
-    def _filter_by_zero(self):
+    def _filter_by_zero(
+        self,
+    ) -> Union[
+        FromClause, _JoinTargetProtocol, ColumnElement[Any], TextClause
+    ]:
         if self._setup_joins:
             meth = SelectState.get_plugin_class(
                 self
@@ -5215,7 +5231,7 @@ class Select(
         from_entity = self._filter_by_zero()
 
         clauses = [
-            _entity_namespace_key(from_entity, key) == value
+            _entity_namespace_key(from_entity, key) == value  # type: ignore
             for key, value in kwargs.items()
         ]
         return self.filter(*clauses)
@@ -5653,6 +5669,7 @@ class Select(
             if isinstance(obj, ColumnClause) and obj.table in new_froms:
                 newelem = new_froms[obj.table].corresponding_column(obj)
                 return newelem
+            return None
 
         kw["replace"] = replace
 
@@ -5941,8 +5958,8 @@ class Select(
         _MemoizedSelectEntities._generate_for_statement(self)
 
         self._raw_columns = [
-            coercions.expect(roles.ColumnsClauseRole, c)
-            for c in coercions._expression_collection_was_a_list(
+            coercions.expect(roles.ColumnsClauseRole, c)  # type: ignore
+            for c in coercions._expression_collection_was_a_list(  # type: ignore # noqa: E501
                 "entities", "Select.with_only_columns", entities
             )
         ]
@@ -6500,13 +6517,13 @@ class ScalarSelect(
         self.element = element
         self.type = element._scalar_type()
 
-    def __getattr__(self, attr: str) -> Callable[..., Any]:
+    def __getattr__(self, attr: str):  # type: ignore
         return getattr(self.element, attr)
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         return {"element": self.element, "type": self.type}
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         self.element = state["element"]
         self.type = state["type"]
 
@@ -6670,7 +6687,9 @@ class Exists(UnaryExpression[bool]):
     def _from_objects(self) -> List[FromClause]:
         return []
 
-    def _regroup(self, fn: Callable) -> SelectStatementGrouping:
+    def _regroup(
+        self, fn: Callable[[Any], ColumnElement[Any]]
+    ) -> ColumnElement[Any]:
         element = self.element._ungroup()
         element = fn(element)
         return element.self_group(against=operators.exists)
@@ -6709,7 +6728,7 @@ class Exists(UnaryExpression[bool]):
         """
         e = self._clone()
         e.element = self._regroup(
-            lambda element: element.correlate(*fromclauses)
+            lambda element: element.correlate(*fromclauses)  # type: ignore
         )
         return e
 
@@ -6728,7 +6747,7 @@ class Exists(UnaryExpression[bool]):
 
         e = self._clone()
         e.element = self._regroup(
-            lambda element: element.correlate_except(*fromclauses)
+            lambda element: element.correlate_except(*fromclauses)  # type: ignore # noqa: E501
         )
         return e
 
@@ -6746,7 +6765,7 @@ class Exists(UnaryExpression[bool]):
 
         """
         e = self._clone()
-        e.element = self._regroup(lambda element: element.select_from(*froms))
+        e.element = self._regroup(lambda element: element.select_from(*froms))  # type: ignore # noqa: E501
         return e
 
     def where(
@@ -6764,7 +6783,7 @@ class Exists(UnaryExpression[bool]):
 
         """
         e = self._clone()
-        e.element = self._regroup(lambda element: element.where(*clause))
+        e.element = self._regroup(lambda element: element.where(*clause))  # type: ignore # noqa: E501
         return e
 
 
@@ -6873,8 +6892,10 @@ class TextualSelect(SelectBase, Executable, Generative):
         self.element = self.element.bindparams(*binds, **bind_as_values)
         return self
 
-    def _generate_fromclause_column_proxies(
-        self, fromclause: Subquery, proxy_compound_columns: zip = None
+    def _generate_fromclause_column_proxies(  # type: ignore
+        self,
+        fromclause: Subquery,
+        proxy_compound_columns: Optional[zip[Any]] = None,
     ) -> None:
         if proxy_compound_columns:
             fromclause._columns._populate_separate_keys(
@@ -6888,7 +6909,7 @@ class TextualSelect(SelectBase, Executable, Generative):
                 c._make_proxy(fromclause) for c in self.column_args
             )
 
-    def _scalar_type(self):
+    def _scalar_type(self) -> Union[TypeEngine[Any], Any]:
         return self.column_args[0].type