]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add type annotations for Function.filter
authorMartijn Pieters <mj@zopatista.com>
Sat, 18 Nov 2023 21:36:08 +0000 (16:36 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Nov 2023 17:22:30 +0000 (12:22 -0500)
This includes all methods / properties on the returned FunctionFilter
object.

This contributes towards #6810

This pull request is:

- [x] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

Closes: #10643
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10643
Pull-request-sha: 6137b7b995b6ea0bd4e4195c5693d2312fa26639

Change-Id: I2af1af7617d0cd3fd30b262d36ff982464bac011

lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/operators.py
test/typing/plain_files/sql/functions_again.py

index 5d7c8ab48a8c033d28964f45075103516f12a52a..c4e503b3cf0cf20db3d6dc430d8756da4db69c1b 100644 (file)
@@ -2696,9 +2696,11 @@ class Null(SingletonConstant, roles.ConstExprRole[None], ColumnElement[None]):
     _traverse_internals: _TraverseInternalsType = []
     _singleton: Null
 
-    @util.memoized_property
-    def type(self):
-        return type_api.NULLTYPE
+    if not TYPE_CHECKING:
+
+        @util.memoized_property
+        def type(self) -> TypeEngine[_T]:  # noqa: A001
+            return type_api.NULLTYPE
 
     @classmethod
     def _instance(cls) -> Null:
@@ -2724,9 +2726,11 @@ class False_(
     _traverse_internals: _TraverseInternalsType = []
     _singleton: False_
 
-    @util.memoized_property
-    def type(self):
-        return type_api.BOOLEANTYPE
+    if not TYPE_CHECKING:
+
+        @util.memoized_property
+        def type(self) -> TypeEngine[_T]:  # noqa: A001
+            return type_api.BOOLEANTYPE
 
     def _negate(self) -> True_:
         return True_._singleton
@@ -2752,9 +2756,11 @@ class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]):
     _traverse_internals: _TraverseInternalsType = []
     _singleton: True_
 
-    @util.memoized_property
-    def type(self):
-        return type_api.BOOLEANTYPE
+    if not TYPE_CHECKING:
+
+        @util.memoized_property
+        def type(self) -> TypeEngine[_T]:  # noqa: A001
+            return type_api.BOOLEANTYPE
 
     def _negate(self) -> False_:
         return False_._singleton
@@ -4268,9 +4274,11 @@ class Over(ColumnElement[_T]):
 
         return lower, upper
 
-    @util.memoized_property
-    def type(self):
-        return self.element.type
+    if not TYPE_CHECKING:
+
+        @util.memoized_property
+        def type(self) -> TypeEngine[_T]:  # noqa: A001
+            return self.element.type
 
     @util.ro_non_memoized_property
     def _from_objects(self) -> List[FromClause]:
@@ -4343,13 +4351,15 @@ class WithinGroup(ColumnElement[_T]):
             rows=rows,
         )
 
-    @util.memoized_property
-    def type(self):
-        wgt = self.element.within_group_type(self)
-        if wgt is not None:
-            return wgt
-        else:
-            return self.element.type
+    if not TYPE_CHECKING:
+
+        @util.memoized_property
+        def type(self) -> TypeEngine[_T]:  # noqa: A001
+            wgt = self.element.within_group_type(self)
+            if wgt is not None:
+                return wgt
+            else:
+                return self.element.type
 
     @util.ro_non_memoized_property
     def _from_objects(self) -> List[FromClause]:
@@ -4399,7 +4409,7 @@ class FunctionFilter(ColumnElement[_T]):
         self.func = func
         self.filter(*criterion)
 
-    def filter(self, *criterion):
+    def filter(self, *criterion: _ColumnExpressionArgument[bool]) -> Self:
         """Produce an additional FILTER against the function.
 
         This method adds additional criteria to the initial criteria
@@ -4463,15 +4473,19 @@ class FunctionFilter(ColumnElement[_T]):
             rows=rows,
         )
 
-    def self_group(self, against=None):
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> Union[Self, Grouping[_T]]:
         if operators.is_precedent(operators.filter_op, against):
             return Grouping(self)
         else:
             return self
 
-    @util.memoized_property
-    def type(self):
-        return self.func.type
+    if not TYPE_CHECKING:
+
+        @util.memoized_property
+        def type(self) -> TypeEngine[_T]:  # noqa: A001
+            return self.func.type
 
     @util.ro_non_memoized_property
     def _from_objects(self) -> List[FromClause]:
index c5eb6b28115f2ba2d4744a312203568c253d6a58..5b54f46ab738c0eb985c1d76f110f450cd2a2efb 100644 (file)
@@ -62,7 +62,6 @@ from .sqltypes import TableValueType
 from .type_api import TypeEngine
 from .visitors import InternalTraversal
 from .. import util
-from ..util.typing import Self
 
 
 if TYPE_CHECKING:
@@ -79,6 +78,7 @@ if TYPE_CHECKING:
     from ..engine.cursor import CursorResult
     from ..engine.interfaces import _CoreMultiExecuteParams
     from ..engine.interfaces import CoreExecuteOptionsParameter
+    from ..util.typing import Self
 
 _T = TypeVar("_T", bound=Any)
 _S = TypeVar("_S", bound=Any)
@@ -484,6 +484,18 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
         """
         return WithinGroup(self, *order_by)
 
+    @overload
+    def filter(self) -> Self:
+        ...
+
+    @overload
+    def filter(
+        self,
+        __criterion0: _ColumnExpressionArgument[bool],
+        *criterion: _ColumnExpressionArgument[bool],
+    ) -> FunctionFilter[_T]:
+        ...
+
     def filter(
         self, *criterion: _ColumnExpressionArgument[bool]
     ) -> Union[Self, FunctionFilter[_T]]:
index 6402d0fd1b2ee9dfaff0102c72e5f6872caf2b2e..1d3f2f483f69e4af2c2888fb4545ed3e2962fc96 100644 (file)
@@ -2582,9 +2582,13 @@ _PRECEDENCE: Dict[OperatorType, int] = {
 }
 
 
-def is_precedent(operator: OperatorType, against: OperatorType) -> bool:
+def is_precedent(
+    operator: OperatorType, against: Optional[OperatorType]
+) -> bool:
     if operator is against and is_natural_self_precedent(operator):
         return False
+    elif against is None:
+        return True
     else:
         return bool(
             _PRECEDENCE.get(
index edfbd6bb2b1435c7d39ec6b26acafdf5a6485f85..5173d1fe0822347ec88b421b7bce2cc56c6708d7 100644 (file)
@@ -21,3 +21,9 @@ func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()])
 func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()])
 func.row_number().over(order_by="a", partition_by=("a", "b"))
 func.row_number().over(partition_by="a", order_by=("a", "b"))
+
+
+# EXPECTED_TYPE: Function[Any]
+reveal_type(func.row_number().filter())
+# EXPECTED_TYPE: FunctionFilter[Any]
+reveal_type(func.row_number().filter(Foo.a > 0))