]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add type annotations to sqlalchemy.sql.selectable
authorDzmitar <17720985+dzmitar@users.noreply.github.com>
Tue, 13 Dec 2022 15:59:17 +0000 (16:59 +0100)
committerDzmitar <17720985+dzmitar@users.noreply.github.com>
Tue, 13 Dec 2022 15:59:17 +0000 (16:59 +0100)
lib/sqlalchemy/sql/selectable.py

index fd4157afdd30d44371fc59106e5bb3f8f88f6b80..37e05177bd2909bc5cc297203d588ff79d61b6d4 100644 (file)
@@ -4,7 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
+# TODO_DELETE_mypy: allow-untyped-defs, allow-untyped-calls
 
 """The :class:`_expression.FromClause` class of SQL expression elements,
 representing
@@ -78,6 +78,7 @@ from .base import HasMemoized
 from .base import Immutable
 from .coercions import _document_text_coercion
 from .elements import _anonymous_label
+from .elements import BinaryExpression
 from .elements import BindParameter
 from .elements import BooleanClauseList
 from .elements import ClauseElement
@@ -1292,7 +1293,10 @@ class Join(roles.DMLTableRole, FromClause):
         # set up a special replace function that will replace for
         # ColumnClause with parent table referring to those
         # replaced FromClause objects
-        def replace(obj, **kw):
+        def replace(
+            obj: Union[BinaryExpression[Any], ColumnClause[Any]],
+            **kw: Any,
+        ) -> Optional[KeyedColumnElement[ColumnElement[Any]]]:
             if isinstance(obj, ColumnClause) and obj.table in new_froms:
                 newelem = new_froms[obj.table].corresponding_column(obj)
                 return newelem
@@ -1461,8 +1465,13 @@ class Join(roles.DMLTableRole, FromClause):
 
     @classmethod
     def _joincond_trim_constraints(
-        cls, a, b, constraints, consider_as_foreign_keys
+        cls,
+        a,  # TEMPtype: Union[Table, Join, Subquery]
+        b,  # TEMPtype: Union[Table, Join]
+        constraints,  # type: Dict
+        consider_as_foreign_keys,  # type: Optional[Any]
     ):
+        # type: (...) -> None
         # more than one constraint matched.  narrow down the list
         # to include just those FKCs that match exactly to
         # "consider_as_foreign_keys".
@@ -1550,6 +1559,7 @@ class Join(roles.DMLTableRole, FromClause):
 
 class NoInit:
     def __init__(self, *arg, **kw):
+        # type: (*Any, **Any) -> NoReturn
         raise NotImplementedError(
             "The %s class is not intended to be constructed "
             "directly.  Please use the %s() standalone "
@@ -1605,6 +1615,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
         return obj
 
     def _init(self, selectable, name=None):
+        # type: (Any, Optional[str]) -> None
         self.element = coercions.expect(
             roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self
         )
@@ -1626,6 +1637,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
         self.element._refresh_for_new_column(column)
 
     def _populate_column_collection(self):
+        # type: () -> None
         self.element._generate_fromclause_column_proxies(self)
 
     @util.ro_non_memoized_property
@@ -2027,15 +2039,16 @@ class CTE(
 
     def _init(
         self,
-        selectable,
-        name=None,
-        recursive=False,
-        nesting=False,
-        _cte_alias=None,
-        _restates=None,
-        _prefixes=None,
-        _suffixes=None,
+        selectable,  # type: Select
+        name=None,  # type: Optional[Any]
+        recursive=False,  # type: bool
+        nesting=False,  # type: bool
+        _cte_alias=None,  # type: Optional[CTE]
+        _restates=None,  # type: Optional[Any]
+        _prefixes=None,  # type: Optional[Tuple[()]]
+        _suffixes=None,  # type: Optional[Tuple[()]]
     ):
+        # type: (...) -> None
         self.recursive = recursive
         self.nesting = nesting
         self._cte_alias = _cte_alias
@@ -2048,6 +2061,7 @@ class CTE(
         super()._init(selectable, name=name)
 
     def _populate_column_collection(self):
+        # type: () -> None
         if self._cte_alias is not None:
             self._cte_alias._generate_fromclause_column_proxies(self)
         else:
@@ -2837,11 +2851,13 @@ class FromGrouping(GroupedElement, FromClause):
         pass
 
     @util.ro_non_memoized_property
-    def columns(self):
+    def columns(self) -> ReadOnlyColumnCollection:
+        # type: () -> ReadOnlyColumnCollection
         return self.element.columns
 
     @util.ro_non_memoized_property
-    def c(self):
+    def c(self) -> ReadOnlyColumnCollection:
+        # type: () -> ReadOnlyColumnCollection
         return self.element.columns
 
     @property
@@ -2849,7 +2865,8 @@ class FromGrouping(GroupedElement, FromClause):
         return self.element.primary_key
 
     @property
-    def foreign_keys(self):
+    def foreign_keys(self) -> Set[ForeignKey]:
+        # type: () -> Set[ForeignKey]
         return self.element.foreign_keys
 
     def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
@@ -2860,7 +2877,8 @@ class FromGrouping(GroupedElement, FromClause):
     ) -> NamedFromGrouping:
         return NamedFromGrouping(self.element.alias(name=name, flat=flat))
 
-    def _anonymous_fromclause(self, **kw):
+    def _anonymous_fromclause(self, **kw: Any) -> FromGrouping:
+        # type: (**Any) -> FromGrouping
         return FromGrouping(self.element._anonymous_fromclause(**kw))
 
     @util.ro_non_memoized_property
@@ -3467,7 +3485,8 @@ class SelectBase(
         return self._implicit_subquery.select(*arg, **kw)
 
     @HasMemoized.memoized_attribute
-    def _implicit_subquery(self):
+    def _implicit_subquery(self) -> Subquery:
+        # type: () -> Subquery
         return self.subquery()
 
     def _scalar_type(self) -> TypeEngine[Any]:
@@ -3643,10 +3662,12 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
 
     element: SelectBase
 
-    def __init__(self, element):
+    def __init__(self, element) -> None:
+        # type: (Select) -> None
         self.element = coercions.expect(roles.SelectStatementRole, element)
 
-    def _ensure_disambiguated_names(self):
+    def _ensure_disambiguated_names(self) -> SelectStatementGrouping:
+        # type: () -> SelectStatementGrouping
         new_element = self.element._ensure_disambiguated_names()
         if new_element is not self.element:
             return SelectStatementGrouping(new_element)
@@ -4363,7 +4384,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
 
         return self
 
-    def _ensure_disambiguated_names(self):
+    def _ensure_disambiguated_names(self) -> CompoundSelect:
+        # type: () -> CompoundSelect
         new_select = self.selects[0]._ensure_disambiguated_names()
         if new_select is not self.selects[0]:
             self = self._generate()
@@ -4417,7 +4439,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
             subquery, proxy_compound_columns=extra_col_iterator
         )
 
-    def _refresh_for_new_column(self, column):
+    def _refresh_for_new_column(self, column: ColumnClause) -> None:
+        # type: (ColumnClause) -> None
         super()._refresh_for_new_column(column)
         for select in self.selects:
             select._refresh_for_new_column(column)
@@ -4548,7 +4571,10 @@ class SelectState(util.MemoizedSlots, CompileState):
         pa = prefix_anon_map()
         names = set()
 
-        def go(c, col_name=None):
+        def go(
+            c: Union[ColumnElement[Any], TextClause],
+            col_name: Optional[str] = None,
+        ) -> Optional[str]:
             if c._is_text_clause:
                 return None
 
@@ -5627,7 +5653,10 @@ class Select(
 
         # 3. clone everything else, making sure we use columns
         # corresponding to the froms we just made.
-        def replace(obj, **kw):
+        def replace(
+            obj: Union[BinaryExpression[Any], ColumnClause[Any]],
+            **kw: Any,
+        ) -> Optional[KeyedColumnElement[ColumnElement[Any]]]:
             if isinstance(obj, ColumnClause) and obj.table in new_froms:
                 newelem = new_froms[obj.table].corresponding_column(obj)
                 return newelem
@@ -6240,7 +6269,8 @@ class Select(
         meth = SelectState.get_plugin_class(self).all_selected_columns
         return list(meth(self))
 
-    def _ensure_disambiguated_names(self):
+    def _ensure_disambiguated_names(self) -> Select:
+        # type: () -> Select
         if self._label_style is LABEL_STYLE_NONE:
             self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY)
         return self
@@ -6303,7 +6333,8 @@ class Select(
 
         subquery._columns._populate_separate_keys(prox)
 
-    def _needs_parens_for_grouping(self):
+    def _needs_parens_for_grouping(self) -> bool:
+        # type: () -> bool
         return self._has_row_limiting_clause or bool(
             self._order_by_clause.clauses
         )
@@ -6474,11 +6505,13 @@ class ScalarSelect(
 
     element: SelectBase
 
-    def __init__(self, element):
+    def __init__(self, element: Select) -> None:
+        # type: (Select) -> None
         self.element = element
         self.type = element._scalar_type()
 
-    def __getattr__(self, attr):
+    def __getattr__(self, attr: str) -> Callable:
+        # type: (str) -> Callable
         return getattr(self.element, attr)
 
     def __getstate__(self):
@@ -6648,7 +6681,8 @@ class Exists(UnaryExpression[bool]):
     def _from_objects(self) -> List[FromClause]:
         return []
 
-    def _regroup(self, fn):
+    def _regroup(self, fn: Callable) -> SelectStatementGrouping:
+        # type: (Callable) -> SelectStatementGrouping
         element = self.element._ungroup()
         element = fn(element)
         return element.self_group(against=operators.exists)
@@ -6791,7 +6825,13 @@ class TextualSelect(SelectBase, Executable, Generative):
     is_text = True
     is_select = True
 
-    def __init__(self, text, columns, positional=False):
+    def __init__(
+        self,
+        text: TextClause,
+        columns: List[ColumnClause],
+        positional: bool = False,
+    ) -> None:
+        # type: (TextClause, List[ColumnClause], bool) -> None
         self.element = text
         # convert for ORM attributes->columns, etc
         self.column_args = [
@@ -6834,7 +6874,8 @@ class TextualSelect(SelectBase, Executable, Generative):
     def set_label_style(self, style):
         return self
 
-    def _ensure_disambiguated_names(self):
+    def _ensure_disambiguated_names(self) -> TextualSelect:
+        # type: () -> TextualSelect
         return self
 
     @_generative
@@ -6847,8 +6888,9 @@ class TextualSelect(SelectBase, Executable, Generative):
         return self
 
     def _generate_fromclause_column_proxies(
-        self, fromclause, proxy_compound_columns=None
-    ):
+        self, fromclause: Subquery, proxy_compound_columns: zip = None
+    ) -> None:
+        # type: (Subquery, zip) -> None
         if proxy_compound_columns:
             fromclause._columns._populate_separate_keys(
                 c._make_proxy(fromclause, compound_select_cols=extra_cols)
@@ -6870,7 +6912,8 @@ TextAsFrom = TextualSelect
 
 
 class AnnotatedFromClause(Annotated):
-    def _copy_internals(self, **kw):
+    def _copy_internals(self, **kw: Any) -> None:
+        # type: (**Any) -> None
         super()._copy_internals(**kw)
         if kw.get("ind_cols_on_fromclause", False):
             ee = self._Annotated__element  # type: ignore