]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Updated typing for self_group()
authorFederico Caselli <cfederico87@gmail.com>
Sat, 4 May 2024 09:23:52 +0000 (11:23 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Sat, 4 May 2024 11:17:29 +0000 (13:17 +0200)
Fixes: #10939
Closes: #11037
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11037
Pull-request-sha: 3ebf4db506ffef629f938f4f36fc76d6671b98e1

Change-Id: I22218286b0dac7bafaaf6955557e25f99a6aefe1
(cherry picked from commit 7173b047788f8a4230647bfc252037c6e227c708)

lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
test/typing/plain_files/sql/operators.py

index 24f04fd76702dfff804d050bc053456e1509bc2a..0d75318296996ddb4d5629152d6826bb2f930148 100644 (file)
@@ -77,6 +77,7 @@ from .. import util
 from ..util import HasMemoized_ro_memoized_attribute
 from ..util import TypingOnly
 from ..util.typing import Literal
+from ..util.typing import ParamSpec
 from ..util.typing import Self
 
 if typing.TYPE_CHECKING:
@@ -1429,13 +1430,11 @@ class ColumnElement(
     _alt_names: Sequence[str] = ()
 
     @overload
-    def self_group(
-        self: ColumnElement[_T], against: Optional[OperatorType] = None
-    ) -> ColumnElement[_T]: ...
+    def self_group(self, against: None = None) -> ColumnElement[_T]: ...
 
     @overload
     def self_group(
-        self: ColumnElement[Any], against: Optional[OperatorType] = None
+        self, against: Optional[OperatorType] = None
     ) -> ColumnElement[Any]: ...
 
     def self_group(
@@ -2581,7 +2580,9 @@ class TextClause(
         # be using this method.
         return self.type.comparator_factory(self)  # type: ignore
 
-    def self_group(self, against=None):
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> Union[Self, Grouping[Any]]:
         if against is operators.in_op:
             return Grouping(self)
         else:
@@ -2786,7 +2787,9 @@ class ClauseList(
     def _from_objects(self) -> List[FromClause]:
         return list(itertools.chain(*[c._from_objects for c in self.clauses]))
 
-    def self_group(self, against=None):
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> Union[Self, Grouping[Any]]:
         if self.group and operators.is_precedent(self.operator, against):
             return Grouping(self)
         else:
@@ -2809,7 +2812,9 @@ class OperatorExpression(ColumnElement[_T]):
     def is_comparison(self):
         return operators.is_comparison(self.operator)
 
-    def self_group(self, against=None):
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> Union[Self, Grouping[_T]]:
         if (
             self.group
             and operators.is_precedent(self.operator, against)
@@ -3169,7 +3174,9 @@ class BooleanClauseList(ExpressionClauseList[bool]):
     def _select_iterable(self) -> _SelectIterable:
         return (self,)
 
-    def self_group(self, against=None):
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> Union[Self, Grouping[bool]]:
         if not self.clauses:
             return self
         else:
@@ -3252,7 +3259,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]):
                 ]
             )
 
-    def self_group(self, against=None):
+    def self_group(self, against: Optional[OperatorType] = None) -> Self:
         # Tuple is parenthesized by definition.
         return self
 
@@ -3485,7 +3492,9 @@ class TypeCoerce(WrapsColumnExpression[_T]):
     def wrapped_column_expression(self):
         return self.clause
 
-    def self_group(self, against=None):
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> TypeCoerce[_T]:
         grouped = self.clause.self_group(against=against)
         if grouped is not self.clause:
             return TypeCoerce(grouped, self.type)
@@ -3700,7 +3709,9 @@ class UnaryExpression(ColumnElement[_T]):
         else:
             return ClauseElement._negate(self)
 
-    def self_group(self, against=None):
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> Union[Self, Grouping[_T]]:
         if self.operator and operators.is_precedent(self.operator, against):
             return Grouping(self)
         else:
@@ -3787,7 +3798,7 @@ class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]):
     def wrapped_column_expression(self):
         return self.element
 
-    def self_group(self, against=None):
+    def self_group(self, against: Optional[OperatorType] = None) -> Self:
         return self
 
     def _negate(self):
@@ -3987,8 +3998,8 @@ class Slice(ColumnElement[Any]):
         )
         self.type = type_api.NULLTYPE
 
-    def self_group(self, against=None):
-        assert against is operator.getitem
+    def self_group(self, against: Optional[OperatorType] = None) -> Self:
+        assert against is operator.getitem  # type: ignore[comparison-overlap]
         return self
 
 
@@ -4006,7 +4017,7 @@ class GroupedElement(DQLDMLClauseElement):
 
     element: ClauseElement
 
-    def self_group(self, against=None):
+    def self_group(self, against: Optional[OperatorType] = None) -> Self:
         return self
 
     def _ungroup(self):
@@ -4070,6 +4081,12 @@ class Grouping(GroupedElement, ColumnElement[_T]):
         self.element = state["element"]
         self.type = state["type"]
 
+    if TYPE_CHECKING:
+
+        def self_group(
+            self, against: Optional[OperatorType] = None
+        ) -> Self: ...
+
 
 class _OverrideBinds(Grouping[_T]):
     """used by cache_key->_apply_params_to_element to allow compilation /
@@ -4570,6 +4587,9 @@ class NamedColumn(KeyedColumnElement[_T]):
         return c.key, c
 
 
+_PS = ParamSpec("_PS")
+
+
 class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
     """Represents a column label (AS).
 
@@ -4667,13 +4687,18 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
     def element(self) -> ColumnElement[_T]:
         return self._element.self_group(against=operators.as_)
 
-    def self_group(self, against=None):
+    def self_group(self, against: Optional[OperatorType] = None) -> Label[_T]:
         return self._apply_to_inner(self._element.self_group, against=against)
 
     def _negate(self):
         return self._apply_to_inner(self._element._negate)
 
-    def _apply_to_inner(self, fn, *arg, **kw):
+    def _apply_to_inner(
+        self,
+        fn: Callable[_PS, ColumnElement[_T]],
+        *arg: _PS.args,
+        **kw: _PS.kwargs,
+    ) -> Label[_T]:
         sub_element = fn(*arg, **kw)
         if sub_element is not self._element:
             return Label(self.name, sub_element, type_=self.type)
index f33e0a41fb77bb0c6bd73e60c702397f38e5e002..143d67b58d33d67eeedcda8954a4b1ab0102f92d 100644 (file)
@@ -1242,7 +1242,6 @@ class Join(roles.DMLTableRole, FromClause):
     def self_group(
         self, against: Optional[OperatorType] = None
     ) -> FromGrouping:
-        ...
         return FromGrouping(self)
 
     @util.preload_module("sqlalchemy.sql.util")
@@ -2889,6 +2888,12 @@ class FromGrouping(GroupedElement, FromClause):
     def __setstate__(self, state: Dict[str, FromClause]) -> None:
         self.element = state["element"]
 
+    if TYPE_CHECKING:
+
+        def self_group(
+            self, against: Optional[OperatorType] = None
+        ) -> Self: ...
+
 
 class NamedFromGrouping(FromGrouping, NamedFromClause):
     """represent a grouping of a named FROM clause
@@ -2899,6 +2904,12 @@ class NamedFromGrouping(FromGrouping, NamedFromClause):
 
     inherit_cache = True
 
+    if TYPE_CHECKING:
+
+        def self_group(
+            self, against: Optional[OperatorType] = None
+        ) -> Self: ...
+
 
 class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
     """Represents a minimal "table" construct.
@@ -3312,6 +3323,12 @@ class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]):
     def __clause_element__(self) -> ScalarValues:
         return self
 
+    if TYPE_CHECKING:
+
+        def self_group(
+            self, against: Optional[OperatorType] = None
+        ) -> Self: ...
+
 
 class SelectBase(
     roles.SelectStatementRole,
@@ -3684,7 +3701,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase, Generic[_SB]):
         return self.element
 
     def self_group(self, against: Optional[OperatorType] = None) -> Self:
-        ...
         return self
 
     if TYPE_CHECKING:
@@ -6325,7 +6341,6 @@ class Select(
     def self_group(
         self, against: Optional[OperatorType] = None
     ) -> Union[SelectStatementGrouping[Self], Self]:
-        ...
         """Return a 'grouping' construct as per the
         :class:`_expression.ClauseElement` specification.
 
@@ -6517,19 +6532,7 @@ class ScalarSelect(
         self.element = cast("Select[Any]", self.element).where(crit)
         return self
 
-    @overload
-    def self_group(
-        self: ScalarSelect[Any], against: Optional[OperatorType] = None
-    ) -> ScalarSelect[Any]: ...
-
-    @overload
-    def self_group(
-        self: ColumnElement[Any], against: Optional[OperatorType] = None
-    ) -> ColumnElement[Any]: ...
-
-    def self_group(
-        self, against: Optional[OperatorType] = None
-    ) -> ColumnElement[Any]:
+    def self_group(self, against: Optional[OperatorType] = None) -> Self:
         return self
 
     if TYPE_CHECKING:
index dbd6f3d48f42f5b0fddff5a30d8fee418f687980..d52461d41f11e4a387e7419c9320a492b5ba310d 100644 (file)
@@ -154,3 +154,8 @@ reveal_type(op_a1)
 # op functions
 t1 = operators.eq(A.id, 1)
 select().where(t1)
+
+# EXPECTED_TYPE: BinaryExpression[Any]
+reveal_type(col.op("->>")("field"))
+# EXPECTED_TYPE: Union[BinaryExpression[Any], Grouping[Any]]
+reveal_type(col.op("->>")("field").self_group())