]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add type annotations for Function.filter
authorMartijn Pieters <mj@zopatista.com>
Wed, 15 Nov 2023 19:04:48 +0000 (19:04 +0000)
committerMartijn Pieters <mj@zopatista.com>
Sat, 18 Nov 2023 15:13:59 +0000 (15:13 +0000)
This includes all methods / properties on the returned FunctionFilter
object.

lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
test/typing/plain_files/sql/functions_again.py

index 49505168c08e196f341240e7e3f9e6b7411d3308..0ac66f37cf6de0704d209f253fc2ca5343f94e9a 100644 (file)
@@ -4408,7 +4408,7 @@ class FunctionFilter(ColumnElement[_T]):
         self.func = func
         self.filter(*criterion)
 
-    def filter(self, *criterion):
+    def filter(self, *criterion: _ColumnExpressionArgument[bool]) -> Self:
         """Produce an additional FILTER against the function.
 
         This method adds additional criteria to the initial criteria
@@ -4472,15 +4472,21 @@ class FunctionFilter(ColumnElement[_T]):
             rows=rows,
         )
 
-    def self_group(self, against=None):
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> Union[Self, Grouping[_T]]:
         if operators.is_precedent(operators.filter_op, against):
             return Grouping(self)
         else:
             return self
 
-    @util.memoized_property
-    def type(self):
-        return self.func.type
+    type: TypeEngine[_T]
+
+    if not TYPE_CHECKING:
+        # A001 is a false-positive here, type is a class member
+        @util.memoized_property
+        def type(self) -> TypeEngine[_T]:  # noqa: A001
+            return self.func.type
 
     @util.ro_non_memoized_property
     def _from_objects(self) -> List[FromClause]:
index fc23e9d2156d9d9d8868abc12072e43e9acba620..a65c03311a22bbed48fa139104f199db900726ce 100644 (file)
@@ -24,6 +24,7 @@ from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
+from typing import Union
 
 from . import annotation
 from . import coercions
@@ -62,11 +63,13 @@ from .. import util
 
 
 if TYPE_CHECKING:
+    from ._typing import _ColumnExpressionArgument
     from ._typing import _TypeEngineArgument
     from ..engine.base import Connection
     from ..engine.cursor import CursorResult
     from ..engine.interfaces import _CoreMultiExecuteParams
     from ..engine.interfaces import CoreExecuteOptionsParameter
+    from ..util.typing import Self
 
 _T = TypeVar("_T", bound=Any)
 
@@ -449,7 +452,21 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
         """
         return WithinGroup(self, *order_by)
 
-    def filter(self, *criterion):
+    @overload
+    def filter(self) -> Self:
+        ...
+
+    @overload
+    def filter(
+        self,
+        __criterion0: _ColumnExpressionArgument[bool],
+        *criterion: _ColumnExpressionArgument[bool],
+    ) -> FunctionFilter[_T]:
+        ...
+
+    def filter(
+        self, *criterion: _ColumnExpressionArgument[bool]
+    ) -> Union[Self, FunctionFilter[_T]]:
         """Produce a FILTER clause against this function.
 
         Used against aggregate and window functions,
index edfbd6bb2b1435c7d39ec6b26acafdf5a6485f85..5173d1fe0822347ec88b421b7bce2cc56c6708d7 100644 (file)
@@ -21,3 +21,9 @@ func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()])
 func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()])
 func.row_number().over(order_by="a", partition_by=("a", "b"))
 func.row_number().over(partition_by="a", order_by=("a", "b"))
+
+
+# EXPECTED_TYPE: Function[Any]
+reveal_type(func.row_number().filter())
+# EXPECTED_TYPE: FunctionFilter[Any]
+reveal_type(func.row_number().filter(Foo.a > 0))