From: Martijn Pieters Date: Wed, 15 Nov 2023 19:04:48 +0000 (+0000) Subject: Add type annotations for Function.filter X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2fb5778579abc90f3189c9631a8f0861a3930bbf;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add type annotations for Function.filter This includes all methods / properties on the returned FunctionFilter object. --- diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 49505168c0..0ac66f37cf 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -4408,7 +4408,7 @@ class FunctionFilter(ColumnElement[_T]): self.func = func self.filter(*criterion) - def filter(self, *criterion): + def filter(self, *criterion: _ColumnExpressionArgument[bool]) -> Self: """Produce an additional FILTER against the function. This method adds additional criteria to the initial criteria @@ -4472,15 +4472,21 @@ class FunctionFilter(ColumnElement[_T]): rows=rows, ) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[_T]]: if operators.is_precedent(operators.filter_op, against): return Grouping(self) else: return self - @util.memoized_property - def type(self): - return self.func.type + type: TypeEngine[_T] + + if not TYPE_CHECKING: + # A001 is a false-positive here, type is a class member + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return self.func.type @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index fc23e9d215..a65c03311a 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -24,6 +24,7 @@ from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . import annotation from . import coercions @@ -62,11 +63,13 @@ from .. import util if TYPE_CHECKING: + from ._typing import _ColumnExpressionArgument from ._typing import _TypeEngineArgument from ..engine.base import Connection from ..engine.cursor import CursorResult from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import CoreExecuteOptionsParameter + from ..util.typing import Self _T = TypeVar("_T", bound=Any) @@ -449,7 +452,21 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): """ return WithinGroup(self, *order_by) - def filter(self, *criterion): + @overload + def filter(self) -> Self: + ... + + @overload + def filter( + self, + __criterion0: _ColumnExpressionArgument[bool], + *criterion: _ColumnExpressionArgument[bool], + ) -> FunctionFilter[_T]: + ... + + def filter( + self, *criterion: _ColumnExpressionArgument[bool] + ) -> Union[Self, FunctionFilter[_T]]: """Produce a FILTER clause against this function. Used against aggregate and window functions, diff --git a/test/typing/plain_files/sql/functions_again.py b/test/typing/plain_files/sql/functions_again.py index edfbd6bb2b..5173d1fe08 100644 --- a/test/typing/plain_files/sql/functions_again.py +++ b/test/typing/plain_files/sql/functions_again.py @@ -21,3 +21,9 @@ func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()]) func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()]) func.row_number().over(order_by="a", partition_by=("a", "b")) func.row_number().over(partition_by="a", order_by=("a", "b")) + + +# EXPECTED_TYPE: Function[Any] +reveal_type(func.row_number().filter()) +# EXPECTED_TYPE: FunctionFilter[Any] +reveal_type(func.row_number().filter(Foo.a > 0))