From c28bcf4b490c6e45c38bde56332b6c68bd1f4ea4 Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Sat, 18 Nov 2023 16:36:08 -0500 Subject: [PATCH] Add type annotations for Function.filter This includes all methods / properties on the returned FunctionFilter object. This contributes towards #6810 This pull request is: - [x] A documentation / typographical / small typing error fix - Good to go, no issue or tests are needed - [ ] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. Closes: #10643 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10643 Pull-request-sha: 6137b7b995b6ea0bd4e4195c5693d2312fa26639 Change-Id: I2af1af7617d0cd3fd30b262d36ff982464bac011 (cherry picked from commit 52452ec39d18567126673eeef4cf0dd12039043b) --- lib/sqlalchemy/sql/elements.py | 62 ++++++++++++------- lib/sqlalchemy/sql/functions.py | 14 ++++- lib/sqlalchemy/sql/operators.py | 6 +- .../typing/plain_files/sql/functions_again.py | 6 ++ 4 files changed, 62 insertions(+), 26 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index cafd291eee..531be31555 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -2694,9 +2694,11 @@ class Null(SingletonConstant, roles.ConstExprRole[None], ColumnElement[None]): _traverse_internals: _TraverseInternalsType = [] _singleton: Null - @util.memoized_property - def type(self): - return type_api.NULLTYPE + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return type_api.NULLTYPE @classmethod def _instance(cls) -> Null: @@ -2722,9 +2724,11 @@ class False_( _traverse_internals: _TraverseInternalsType = [] _singleton: False_ - @util.memoized_property - def type(self): - return type_api.BOOLEANTYPE + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return type_api.BOOLEANTYPE def _negate(self) -> True_: return True_._singleton @@ -2750,9 +2754,11 @@ class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]): _traverse_internals: _TraverseInternalsType = [] _singleton: True_ - @util.memoized_property - def type(self): - return type_api.BOOLEANTYPE + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return type_api.BOOLEANTYPE def _negate(self) -> False_: return False_._singleton @@ -4266,9 +4272,11 @@ class Over(ColumnElement[_T]): return lower, upper - @util.memoized_property - def type(self): - return self.element.type + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return self.element.type @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: @@ -4341,13 +4349,15 @@ class WithinGroup(ColumnElement[_T]): rows=rows, ) - @util.memoized_property - def type(self): - wgt = self.element.within_group_type(self) - if wgt is not None: - return wgt - else: - return self.element.type + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + wgt = self.element.within_group_type(self) + if wgt is not None: + return wgt + else: + return self.element.type @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: @@ -4397,7 +4407,7 @@ class FunctionFilter(ColumnElement[_T]): self.func = func self.filter(*criterion) - def filter(self, *criterion): + def filter(self, *criterion: _ColumnExpressionArgument[bool]) -> Self: """Produce an additional FILTER against the function. This method adds additional criteria to the initial criteria @@ -4461,15 +4471,19 @@ class FunctionFilter(ColumnElement[_T]): rows=rows, ) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[_T]]: if operators.is_precedent(operators.filter_op, against): return Grouping(self) else: return self - @util.memoized_property - def type(self): - return self.func.type + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return self.func.type @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index c5eb6b2811..5b54f46ab7 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -62,7 +62,6 @@ from .sqltypes import TableValueType from .type_api import TypeEngine from .visitors import InternalTraversal from .. import util -from ..util.typing import Self if TYPE_CHECKING: @@ -79,6 +78,7 @@ if TYPE_CHECKING: from ..engine.cursor import CursorResult from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import CoreExecuteOptionsParameter + from ..util.typing import Self _T = TypeVar("_T", bound=Any) _S = TypeVar("_S", bound=Any) @@ -484,6 +484,18 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): """ return WithinGroup(self, *order_by) + @overload + def filter(self) -> Self: + ... + + @overload + def filter( + self, + __criterion0: _ColumnExpressionArgument[bool], + *criterion: _ColumnExpressionArgument[bool], + ) -> FunctionFilter[_T]: + ... + def filter( self, *criterion: _ColumnExpressionArgument[bool] ) -> Union[Self, FunctionFilter[_T]]: diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 6402d0fd1b..1d3f2f483f 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -2582,9 +2582,13 @@ _PRECEDENCE: Dict[OperatorType, int] = { } -def is_precedent(operator: OperatorType, against: OperatorType) -> bool: +def is_precedent( + operator: OperatorType, against: Optional[OperatorType] +) -> bool: if operator is against and is_natural_self_precedent(operator): return False + elif against is None: + return True else: return bool( _PRECEDENCE.get( diff --git a/test/typing/plain_files/sql/functions_again.py b/test/typing/plain_files/sql/functions_again.py index edfbd6bb2b..5173d1fe08 100644 --- a/test/typing/plain_files/sql/functions_again.py +++ b/test/typing/plain_files/sql/functions_again.py @@ -21,3 +21,9 @@ func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()]) func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()]) func.row_number().over(order_by="a", partition_by=("a", "b")) func.row_number().over(partition_by="a", order_by=("a", "b")) + + +# EXPECTED_TYPE: Function[Any] +reveal_type(func.row_number().filter()) +# EXPECTED_TYPE: FunctionFilter[Any] +reveal_type(func.row_number().filter(Foo.a > 0)) -- 2.47.2