From 7173b047788f8a4230647bfc252037c6e227c708 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Sat, 4 May 2024 11:23:52 +0200 Subject: [PATCH] Updated typing for self_group() Fixes: #10939 Closes: #11037 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11037 Pull-request-sha: 3ebf4db506ffef629f938f4f36fc76d6671b98e1 Change-Id: I22218286b0dac7bafaaf6955557e25f99a6aefe1 --- lib/sqlalchemy/sql/elements.py | 59 +++++++++++++++++------- lib/sqlalchemy/sql/selectable.py | 35 +++++++------- test/typing/plain_files/sql/operators.py | 5 ++ 3 files changed, 66 insertions(+), 33 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 1fadbe19d4..6aecfe203b 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -77,6 +77,7 @@ from .. import util from ..util import HasMemoized_ro_memoized_attribute from ..util import TypingOnly from ..util.typing import Literal +from ..util.typing import ParamSpec from ..util.typing import Self from ..util.typing import TupleAny from ..util.typing import Unpack @@ -1433,13 +1434,11 @@ class ColumnElement( _alt_names: Sequence[str] = () @overload - def self_group( - self: ColumnElement[_T], against: Optional[OperatorType] = None - ) -> ColumnElement[_T]: ... + def self_group(self, against: None = None) -> ColumnElement[_T]: ... @overload def self_group( - self: ColumnElement[Any], against: Optional[OperatorType] = None + self, against: Optional[OperatorType] = None ) -> ColumnElement[Any]: ... def self_group( @@ -2583,7 +2582,9 @@ class TextClause( # be using this method. return self.type.comparator_factory(self) # type: ignore - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[Any]]: if against is operators.in_op: return Grouping(self) else: @@ -2788,7 +2789,9 @@ class ClauseList( def _from_objects(self) -> List[FromClause]: return list(itertools.chain(*[c._from_objects for c in self.clauses])) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[Any]]: if self.group and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -2811,7 +2814,9 @@ class OperatorExpression(ColumnElement[_T]): def is_comparison(self): return operators.is_comparison(self.operator) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[_T]]: if ( self.group and operators.is_precedent(self.operator, against) @@ -3171,7 +3176,9 @@ class BooleanClauseList(ExpressionClauseList[bool]): def _select_iterable(self) -> _SelectIterable: return (self,) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[bool]]: if not self.clauses: return self else: @@ -3254,7 +3261,7 @@ class Tuple(ClauseList, ColumnElement[TupleAny]): ] ) - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> Self: # Tuple is parenthesized by definition. return self @@ -3487,7 +3494,9 @@ class TypeCoerce(WrapsColumnExpression[_T]): def wrapped_column_expression(self): return self.clause - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> TypeCoerce[_T]: grouped = self.clause.self_group(against=against) if grouped is not self.clause: return TypeCoerce(grouped, self.type) @@ -3702,7 +3711,9 @@ class UnaryExpression(ColumnElement[_T]): else: return ClauseElement._negate(self) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[_T]]: if self.operator and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -3789,7 +3800,7 @@ class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]): def wrapped_column_expression(self): return self.element - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> Self: return self def _negate(self): @@ -3989,8 +4000,8 @@ class Slice(ColumnElement[Any]): ) self.type = type_api.NULLTYPE - def self_group(self, against=None): - assert against is operator.getitem + def self_group(self, against: Optional[OperatorType] = None) -> Self: + assert against is operator.getitem # type: ignore[comparison-overlap] return self @@ -4008,7 +4019,7 @@ class GroupedElement(DQLDMLClauseElement): element: ClauseElement - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> Self: return self def _ungroup(self): @@ -4072,6 +4083,12 @@ class Grouping(GroupedElement, ColumnElement[_T]): self.element = state["element"] self.type = state["type"] + if TYPE_CHECKING: + + def self_group( + self, against: Optional[OperatorType] = None + ) -> Self: ... + class _OverrideBinds(Grouping[_T]): """used by cache_key->_apply_params_to_element to allow compilation / @@ -4572,6 +4589,9 @@ class NamedColumn(KeyedColumnElement[_T]): return c.key, c +_PS = ParamSpec("_PS") + + class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]): """Represents a column label (AS). @@ -4669,13 +4689,18 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]): def element(self) -> ColumnElement[_T]: return self._element.self_group(against=operators.as_) - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> Label[_T]: return self._apply_to_inner(self._element.self_group, against=against) def _negate(self): return self._apply_to_inner(self._element._negate) - def _apply_to_inner(self, fn, *arg, **kw): + def _apply_to_inner( + self, + fn: Callable[_PS, ColumnElement[_T]], + *arg: _PS.args, + **kw: _PS.kwargs, + ) -> Label[_T]: sub_element = fn(*arg, **kw) if sub_element is not self._element: return Label(self.name, sub_element, type_=self.type) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 1727447a2c..4e716e7061 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1247,7 +1247,6 @@ class Join(roles.DMLTableRole, FromClause): def self_group( self, against: Optional[OperatorType] = None ) -> FromGrouping: - ... return FromGrouping(self) @util.preload_module("sqlalchemy.sql.util") @@ -2894,6 +2893,12 @@ class FromGrouping(GroupedElement, FromClause): def __setstate__(self, state: Dict[str, FromClause]) -> None: self.element = state["element"] + if TYPE_CHECKING: + + def self_group( + self, against: Optional[OperatorType] = None + ) -> Self: ... + class NamedFromGrouping(FromGrouping, NamedFromClause): """represent a grouping of a named FROM clause @@ -2904,6 +2909,12 @@ class NamedFromGrouping(FromGrouping, NamedFromClause): inherit_cache = True + if TYPE_CHECKING: + + def self_group( + self, against: Optional[OperatorType] = None + ) -> Self: ... + class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): """Represents a minimal "table" construct. @@ -3317,6 +3328,12 @@ class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]): def __clause_element__(self) -> ScalarValues: return self + if TYPE_CHECKING: + + def self_group( + self, against: Optional[OperatorType] = None + ) -> Self: ... + class SelectBase( roles.SelectStatementRole, @@ -3689,7 +3706,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase, Generic[_SB]): return self.element def self_group(self, against: Optional[OperatorType] = None) -> Self: - ... return self if TYPE_CHECKING: @@ -6344,7 +6360,6 @@ class Select( def self_group( self, against: Optional[OperatorType] = None ) -> Union[SelectStatementGrouping[Self], Self]: - ... """Return a 'grouping' construct as per the :class:`_expression.ClauseElement` specification. @@ -6538,19 +6553,7 @@ class ScalarSelect( ) return self - @overload - def self_group( - self: ScalarSelect[Any], against: Optional[OperatorType] = None - ) -> ScalarSelect[Any]: ... - - @overload - def self_group( - self: ColumnElement[Any], against: Optional[OperatorType] = None - ) -> ColumnElement[Any]: ... - - def self_group( - self, against: Optional[OperatorType] = None - ) -> ColumnElement[Any]: + def self_group(self, against: Optional[OperatorType] = None) -> Self: return self if TYPE_CHECKING: diff --git a/test/typing/plain_files/sql/operators.py b/test/typing/plain_files/sql/operators.py index dbd6f3d48f..d52461d41f 100644 --- a/test/typing/plain_files/sql/operators.py +++ b/test/typing/plain_files/sql/operators.py @@ -154,3 +154,8 @@ reveal_type(op_a1) # op functions t1 = operators.eq(A.id, 1) select().where(t1) + +# EXPECTED_TYPE: BinaryExpression[Any] +reveal_type(col.op("->>")("field")) +# EXPECTED_TYPE: Union[BinaryExpression[Any], Grouping[Any]] +reveal_type(col.op("->>")("field").self_group()) -- 2.47.2